RALLM: Retrieval-Augmented LLM Architecture
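The first file below, modeling_RALLM.py, defines the retrieval-augmented model: retrieved passages go through a context encoder (BERT / GPT-2 / Longformer), a token-scoring head plus a small perceive sampler (Q-Former) compresses them, llm_proj maps the result into the Baichuan2 embedding space, and the compressed states are spliced into the prompt between flag_context_start and flag_context_end. A minimal usage sketch is shown first; the argument values and local checkpoint paths are assumptions that must exist on disk (note that the __main__ block in the file itself registers only a subset of these flags, so it would need the extra arguments below to run).

import argparse
from modeling_RALLM import RALLM  # the file listed below

# Hypothetical arguments mirroring the argparse setup in modeling_RALLM.py /
# train_batch.py; the Baichuan2 and encoder checkpoints must be available locally.
args = argparse.Namespace(
    is_compress=True,
    use_lora=False,
    LLM_model="Baichuan2_7B",
    encoder="longformer",
    query_tokens=32,
    local_rank=0,
)

model = RALLM(args)

# generate() strips the <context> placeholder from the question and inserts the
# compressed context embeddings between the user and assistant tokens.
reply = model.generate(
    text="请总结被压缩的信息<context>",
    context=[["电饭煲不知道怎么选?..."]],  # one sample containing one retrieved passage
)
print(reply)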
import copy import os import sys dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, dir_path) import contextlib import torch.utils.checkpoint from torch.nn import LayerNorm from torch import nn from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modeling_perceive_sampler import BertConfig, BertLMHeadModel from transformers.utils import logging logger = logging.get_logger(__name__) import transformers from transformers import PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM,AutoModel,BertTokenizer,GPT2LMHeadModel,PretrainedConfig,GPT2Model,GPT2Tokenizer,LongformerTokenizer, LongformerModel os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 使用第一个GPU import argparse import math class RALLM(nn.Module): def __init__(self,args): super(RALLM,self).__init__() self.is_compress = args.is_compress self.use_lora = args.use_lora print('Init LLM ... ') if args.LLM_model == "Baichuan2_13B": self.LLM_model_name = "Baichuan2-13B-Chat" self.LLM_hidden_size = 5120 elif args.LLM_model == "Baichuan2_7B": self.LLM_model_name = "baichuan2_7B" self.LLM_hidden_size = 4096 self.LLM_model = transformers.AutoModelForCausalLM.from_pretrained( self.LLM_model_name, device_map=f"cuda:{args.local_rank}", trust_remote_code=True, torch_dtype=torch.bfloat16, # cache_dir=training_args.cache_dir, ) self.LLM_tokenizer = transformers.AutoTokenizer.from_pretrained( self.LLM_model_name, use_fast=False, trust_remote_code=True, model_max_length=4096, # cache_dir=training_args.cache_dir, ) self.flag_context_start = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))#.to(self.device) self.flag_context_end = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))#.to(self.device) self.flag_context_start.requires_grad = False self.flag_context_end.requires_grad = False self.device = self.LLM_model.device self.user_token = self.LLM_tokenizer._convert_id_to_token(195) self.assisent_token = self.LLM_tokenizer._convert_id_to_token(196) self.eoa = self.LLM_tokenizer._convert_id_to_token(2) print("user_token:",self.user_token,"assisent_token:",self.assisent_token,"eoa:",self.eoa) print('Done') print('Init context encoder ... 
') self.init_context_encoder(args) print('Done') def init_Qformer(self,num_query_token,num_features): self.Qformer = self.init_qformer(num_query_token, num_features,cross_attention_freq=1) self.Qformer.bert.embeddings.word_embeddings = None self.Qformer.bert.embeddings.position_embeddings = None for layer in self.Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None self.Qformer.cls = None @classmethod def init_qformer(cls, num_query_token, vision_width, cross_attention_freq=2, pretrain=True): encoder_config = BertConfig() encoder_config.num_hidden_layers = 2 encoder_config.hidden_size = vision_width encoder_config.encoder_width = vision_width encoder_config.num_attention_heads = vision_width//64 # insert cross-attention layer every other block encoder_config.add_cross_attention = True encoder_config.cross_attention_freq = cross_attention_freq encoder_config.query_length = num_query_token Qformer = BertLMHeadModel(config=encoder_config) return Qformer def init_context_encoder(self,args): num_query_token = args.query_tokens = 0 if args.encoder == "bert_base": self.context_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese") self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_base_chinese",output_hidden_states=True) num_features = 768 if args.encoder == "bert_large": self.context_tokenizer = AutoTokenizer.from_pretrained("bert_large_chinese",max_length=2000) self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_large_chinese",output_hidden_states=True) num_features = 1024 if args.encoder == "gpt2_xlarge": self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_xlarge") self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_xlarge") num_features = 1600 if args.encoder == "gpt2_large": self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_large") self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_large") num_features = 1280 if args.encoder == "gpt2_large_en": self.context_tokenizer = GPT2Tokenizer.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN") self.context_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) self.context_encoder = GPT2Model.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN") num_features = 1280 if args.encoder == "longformer": self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer') self.context_encoder = LongformerModel.from_pretrained('longformer') num_features = 768 if args.encoder == "longformer_large": self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer-large') self.context_encoder = LongformerModel.from_pretrained('longformer-large') num_features = 1024 # bert_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese",max_length=2000) # bert_encoder = AutoModelForMaskedLM.from_pretrained("longformer_zh",output_hidden_states=True) #.to(device) self.context_encoder = self.context_encoder.to(self.device) self.context_score = torch.nn.ModuleList([ torch.nn.Linear(num_features, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1), ]) # 768是BERT的隐藏状态维度,1是目标输出维度 self.context2llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size) # 768是BERT的隐藏状态维度,1是目标输出维度 self.llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size) # model.embed2qformer_proj = torch.nn.Linear(num_features, 768) self.ln_features = LayerNorm(num_features) self.init_Qformer(num_query_token,num_features) # del model.internlm_proj # del model.Qformer # torch.cuda.empty_cache() # 释放显存 # if device: # model = 
self.model.to(self.device) def encode_text(self, text, add_special_tokens=False): input_ids = self.LLM_tokenizer.encode(text) input_ids = torch.LongTensor([input_ids]).to(self.device) if self.use_lora: text_embeds = self.LLM_model.base_model.model.model.embed_tokens(input_ids) else: text_embeds = self.LLM_model.model.embed_tokens(input_ids) return text_embeds def calculate_compressibility(self,x,k=0): return (x * k*(9 / 1000) + 1) * 111.111 / (x + 111.111) # 批量输入句子 def batch_input_sentences(self,sentences): input_ids_list = [self.context_tokenizer.encode(sentence,return_tensors="pt",padding='max_length', max_length=2500, truncation=True) for sentence in sentences] max_length = max(len(input_ids[0]) for input_ids in input_ids_list) input_ids_padded = [torch.cat([input_ids, torch.zeros(1, max_length - input_ids.size(1), dtype=torch.long)], dim=1) for input_ids in input_ids_list] input_ids_tensor = torch.cat(input_ids_padded, dim=0) return input_ids_tensor def encode_context(self, text_list): if text_list is None: return None inputs_LLMs = [] input_atts = [] # print(text_list) for text in text_list: # input_ids = self.context_tokenizer.encode(text, add_special_tokens=True, return_tensors="pt") # 对文本列表进行编码并进行最长的填充 # encoded_ids = self.context_tokenizer(text, padding=True, return_tensors="pt", truncation=True) input_ids = self.batch_input_sentences(text) input_ids =input_ids.to(self.device) # input_ids = encoded_ids.data["input_ids"].to(self.device) # attention_mask = encoded_ids.data["attention_mask"].to(self.device) outputs = self.context_encoder(input_ids,output_hidden_states=True) # 提取最后一层的隐藏状态向量 embedding,last_hidden_state = outputs.hidden_states[0],outputs.hidden_states[-1] #outputs.logits x = last_hidden_state for layer in self.context_score: x = layer(x) output = x # output = self.context_score(last_hidden_state) # 进行线性变换 batch,seq_len,ebd_dim = last_hidden_state.size() # compressibility = -0.0009 * seq_len+1 #压缩率计算长度越低压缩率越低,长度越长,压缩率越高。线性压缩不好 x*f(x) 不是单调递减的 # compressibility = 111.111/(seq_len+111.111) #重新设计非线性压缩 10以下不压缩,0-1000 x*f(x) 递减 if self.is_compress: compressibility = self.calculate_compressibility(seq_len,0) K = math.ceil(seq_len*compressibility) else: K = seq_len # 使用 torch.topk 函数获取 top k 的索引 topk_indices = torch.topk(output, K,dim=1).indices # print(topk_indices) topk_indices, sorted_indices = torch.sort(topk_indices,dim=1) #恢复原文顺序 # print(topk_indices) # 计算 top k 对应的 last_hidden_state topk_selected_last_hidden_state = torch.gather(last_hidden_state, 1, topk_indices.expand(-1, -1, ebd_dim)) # print(last_hidden_state) # print(topk_selected_last_hidden_state) topk_selected_embedding = torch.gather(embedding, 1, topk_indices.expand(-1, -1, ebd_dim)) # bert_text_atts = torch.gather(attention_mask, 1, torch.squeeze(topk_indices, dim=2)) bert_text_embeds = self.ln_features(last_hidden_state) bert_text_atts = torch.ones(bert_text_embeds.size()[:-1],dtype=torch.long).to(self.device) # query_tokens = self.query_tokens.expand(bert_text_atts.shape[0], -1,-1) # query_tokens = topk_selected_embedding query_tokens = topk_selected_last_hidden_state query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=bert_text_embeds, encoder_attention_mask=bert_text_atts, return_dict=True, ) # topk_context_hidden_state = self.context2llm_proj(topk_selected_last_hidden_state) inputs_LLM = self.llm_proj(query_output.last_hidden_state) inputs_LLM = torch.cat([ self.flag_context_start.expand(batch, -1, -1), # topk_context_hidden_state, inputs_LLM, 
self.flag_context_end.expand(batch, -1, -1) ],dim=1).view(-1, self.LLM_hidden_size) input_att = torch.cat([torch.ones((batch,1)).to(self.device),bert_text_atts,torch.ones((batch,1)).to(self.device)],dim=1).view(-1) # print(inputs_LLM.shape) inputs_LLMs.append(inputs_LLM) input_atts.append(input_att) # context_inputs = torch.stack(inputs_LLMs) return inputs_LLMs,input_atts def wrap_prompt(self, text_embeds, context_embeds=None, history=None, add_special=True): if add_special: if history is None: prompt_segs = [ self.user_token, self.assisent_token ] else: prompt_segs = [self.user_token, self.assisent_token] else: prompt_segs = [self.user_token, self.assisent_token] # used in wrap history prompt_seg_embeds = [] for i, seg in enumerate(prompt_segs): if history is not None: add_special_tokens = False else: add_special_tokens = i == 0 seg_embeds = self.encode_text( seg, add_special_tokens=add_special_tokens) prompt_seg_embeds.append(seg_embeds) if context_embeds is None: context_embeds = text_embeds.new_empty(text_embeds.size(0), 0, text_embeds.size(-1)) else: # 在第一个维度(索引为0)添加一个维度 context_embeds = context_embeds[0].unsqueeze(0) prompt_seg_embeds = [ prompt_seg_embeds[0], text_embeds,context_embeds, prompt_seg_embeds[1] ] prompt_embeds = torch.cat(prompt_seg_embeds, dim=1) if history is not None: prompt_embeds = torch.cat([*history, prompt_embeds], dim=1) return prompt_embeds def generate(self, text, context=None, **kwargs): text = text.replace("<context>","").replace(self.user_token,"").replace(self.assisent_token,"") text_embeds = self.encode_text(text) context_embeds,_ = self.encode_context(context) prompt_embeds = self.wrap_prompt(text_embeds, context_embeds) # out_embeds = self.LLM_model.generate(input_ids=None, # inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs)) # out_text = self.decode_text(out_embeds) outputs = self.LLM_model.generate(input_ids=None,inputs_embeds=prompt_embeds, generation_config=self.LLM_model.generation_config) response = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True) return response def chat(self, text, context=None, history=None, **kwargs): text_embeds = self.encode_text(text) img_embeds = self.encode_context(context) prompt_embeds = self.wrap_prompt(text_embeds, img_embeds, history=history) out_embeds = self.internlm_model.generate( inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs)) out_text = self.decode_text(out_embeds) # trunc at eoh and eoa clean_out_text_token_ids = self.tokenizer( out_text, return_tensors='pt').input_ids.to(self.device) clean_out_text_embeds = self.internlm_model.model.embed_tokens( clean_out_text_token_ids) clean_prompt_embeds = self.wrap_prompt(text_embeds, img_embeds, add_special=False) cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds], dim=1) if history is None: history = [] history.append(cur_history) return out_text, history def align_text(self, samples, has_context=False): ### add eos and eoa 返回<context>后的text text_new = [] if has_context: ### remove the first user to wrap image features text = [ t.split("<context>")[-1] for t in samples["text_input"] ] else: text = [t for t in samples["text_input"]] text = [t + self.eoa for t in text] for i in range(len(text)): temp = text[i] # temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>') # if temp.find(self.eoh) > temp.find(self.eoa): # temp = temp.replace(self.eoa, '', 1) text_new.append(temp) return text_new def prompt_wrap(self, context_embeds,context_atts, prompt_list): batch_size = len(context_embeds) p_before = 
[prompt.split('<context>')[0] for prompt in prompt_list] p_before_tokens = self.LLM_tokenizer(p_before, padding=True, truncation=True, return_tensors="pt", add_special_tokens=True).to( self.device) if self.use_lora: p_before_embeds = self.LLM_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) else: p_before_embeds = self.LLM_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # wrapped_context_embeds = torch.cat([p_before_embeds, context_embeds], dim=1) # wrapped_context_embeds = torch.cat([p_before_embeds]+context_embeds, dim=1) wrapped_context_embeds = [] wrapped_atts_context = [] wrapped_target = [] for i, (context_embed,context_att) in enumerate(zip(context_embeds,context_atts)): # 将p_before_embeds的每个序列与相应的张量在序列长度维度上拼接 concatenated = torch.cat((p_before_embeds[i], context_embed), dim=0) wrapped_context_embeds.append(concatenated) # concatenated_att = torch.cat((torch.ones(p_before_embeds[i].size()[:-1],dtype=torch.long).to(self.device),context_att),dim=0) wrapped_atts_context.append(torch.ones(concatenated.size()[:-1],dtype=torch.long).to(self.device)) # wrapped_atts_context.append(concatenated_att) target = torch.ones(concatenated.size()[:-1], dtype=torch.long) * -100 target[0] = 2 target = target.to(self.device) wrapped_target.append(target) # wrapped_atts_context = torch.ones(wrapped_context_embeds.size()[:-1], # dtype=torch.long).to(self.device) # wrapped_target = torch.ones( # batch_size, wrapped_context_embeds.shape[1], dtype=torch.long).to( # self.device) * -100 return wrapped_context_embeds, wrapped_atts_context, wrapped_target def text2emb(self, text): to_regress_tokens = self.LLM_tokenizer(text, return_tensors="pt", padding="longest", truncation=True, max_length=4096, add_special_tokens=False).to( self.device) targets = self.mask_human_targets(to_regress_tokens.input_ids) targets = targets.to(self.device) return to_regress_tokens, targets def mask_human_targets(self, input_ids, pure=False): target_batch = [] for bs in range(input_ids.shape[0]): cur_idx = 0 ids = input_ids[bs] targets = copy.deepcopy(ids) last_eoa = 0 last_eoh = 0 for i, temp_id in enumerate(ids): if temp_id == 196: #### end of human targets[cur_idx:i+1] = -100 target_batch.append(targets.unsqueeze(0)) target_batch = torch.cat(target_batch, dim=0) target_batch[target_batch==0]=-100 # print(input_ids) # print(target_batch) return target_batch def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, context = None, text_input = None, **kwargs): # samples = kwargs #.get('samples') # has_context = 'context' in samples.keys() if context: has_context = True else: has_context = False samples = {"text_input":text_input,"context":context} ### encode text text = self.align_text(samples=samples, has_context=has_context) #获取<context> 后面的text to_regress_tokens, targets = self.text2emb(text) #返回token和target if self.use_lora: to_regress_embeds = self.LLM_model.base_model.model.model.embed_tokens(to_regress_tokens.input_ids) else: to_regress_embeds = self.LLM_model.model.embed_tokens(to_regress_tokens.input_ids) attention_mask = to_regress_tokens.attention_mask if has_context: prompt = samples["text_input"] ### encode context context = samples["context"] context_embeds,context_atts = self.encode_context(context) context_embeds, atts_context, wrapped_target = self.prompt_wrap( context_embeds,context_atts, prompt) ### combine text and image to_regress_embeds_ 
= [] attention_mask_ = [] targets_ = [] for i, (tensor0,tensor1,tensor2) in enumerate(zip(to_regress_embeds,attention_mask,targets)): # 将p_before_embeds的每个序列与相应的张量在序列长度维度上拼接 to_regress_embed = torch.cat((context_embeds[i], tensor0), dim=0) to_regress_embeds_.append(to_regress_embed) attention_m = torch.cat((atts_context[i], tensor1), dim=0) attention_mask_.append(attention_m) target = torch.cat((wrapped_target[i], tensor2), dim=0) targets_.append(target) # to_regress_embeds = torch.cat([context_embeds, to_regress_embeds], # dim=1) # attention_mask = torch.cat([atts_context, attention_mask], dim=1) # targets = torch.cat([wrapped_target, targets], dim=1) # 确定最大长度 max_len = max(t.size(0) for t in to_regress_embeds_) # 填充张量 padded_to_regress_embeds_ = [] padded_attention_mask_ = [] padded_targets_ = [] for (t,a,l) in zip(to_regress_embeds_,attention_mask_,targets_): if t.size(0) < max_len: # 计算需要填充的长度 padding_size = max_len - t.size(0) # 在序列维度上进行填充 padded_regress = torch.nn.functional.pad(t, (0, 0, 0, padding_size)) padded_attention = torch.nn.functional.pad(a, (0, padding_size), value=0) padded_target = torch.nn.functional.pad(l, (0, padding_size), value=-100) padded_to_regress_embeds_.append(padded_regress) padded_attention_mask_.append(padded_attention) padded_targets_.append(padded_target) else: padded_to_regress_embeds_.append(t) padded_attention_mask_.append(a) padded_targets_.append(l) # 合并张量 to_regress_embeds = torch.stack(padded_to_regress_embeds_) attention_mask = torch.stack(padded_attention_mask_) targets = torch.stack(padded_targets_) outputs = self.LLM_model( inputs_embeds=to_regress_embeds, attention_mask=attention_mask, return_dict=True, labels=targets, ) return outputs if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--output", default="output", type=str) parser.add_argument("--encoder", default="gpt2_large", type=str) parser.add_argument("--query_tokens", default=32, type=int) parser.add_argument("--load_path", default="/data2/xinyuuliu/InternLM-XComposer/output_rerank", type=str) parser.add_argument("--local_rank", default="0", type=str) args = parser.parse_args() model = RALLM(args) print(model) # model.encode_context("我爱北京天安门") # model.encode_text("我爱北京天安门") # #<ContextHere> # query = "Q:请重复内容:<cont_s><ContextHere><cont_e> \n A:" # context = ["电饭煲不知道怎么选?想要吃一碗香喷喷的米饭,除了米要好之外,还需要一款性能优秀的电饭煲,所以大家在选购电饭煲的时候,一定要多花点心思看看攻略避免踩雷。我前前后后给亲朋好友选购过不下5台电饭煲,也算是积攒了不少选购经验,今天特意总结了一下想分享给大家。1、容量选择市面上电饭煲容量普遍在3L-5L之间,这个范围的容量足够满足绝大部分家庭使用,3L一般可以满足1-3人的家庭,4L一般可以满足2-5人的家庭,5L一般可以满足2-8人的家庭,如果人口超过8人建议直接选择5L以上的容量,使用会更方便。"] # model.interleav_wrap(query,context)
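The heart of the compression step sits in calculate_compressibility and the top-k selection inside encode_context: the longer the retrieved passage, the larger the fraction of token states that is dropped, and the surviving states are re-sorted so the original word order is preserved before they are handed to the perceive sampler as query states. A self-contained sketch of just that step, using random tensors in place of real encoder outputs:

import math
import torch

def calculate_compressibility(seq_len, k=0):
    # Same schedule as RALLM.calculate_compressibility: close to 1 for short
    # inputs (almost no compression), decaying towards 0 as seq_len grows.
    return (seq_len * k * (9 / 1000) + 1) * 111.111 / (seq_len + 111.111)

def select_topk_tokens(last_hidden_state, scores, compress=True):
    """Keep the K highest-scoring token states and restore document order.

    last_hidden_state: (batch, seq_len, dim) context-encoder outputs
    scores:            (batch, seq_len, 1) output of the context_score MLP
    """
    batch, seq_len, dim = last_hidden_state.size()
    K = math.ceil(seq_len * calculate_compressibility(seq_len)) if compress else seq_len
    topk_indices = torch.topk(scores, K, dim=1).indices    # (batch, K, 1)
    topk_indices, _ = torch.sort(topk_indices, dim=1)      # restore original token order
    return torch.gather(last_hidden_state, 1, topk_indices.expand(-1, -1, dim))

# Toy run with random tensors standing in for Longformer outputs.
hidden = torch.randn(2, 500, 768)
scores = torch.randn(2, 500, 1)
print(select_topk_tokens(hidden, scores).shape)  # torch.Size([2, 91, 768])

With is_compress disabled, K simply equals seq_len and every token state is passed on, so the Q-Former re-encodes the full context without any pruning.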
modeling_perceive_sampler.py
""" * Copyright (c) 2023, salesforce.com, inc. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause * By Junnan Li * Based on huggingface code base * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert """ import math from typing import Tuple import torch import torch.utils.checkpoint from torch import Tensor, device from torch import nn from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, ) from transformers.modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) from transformers.models.bert.configuration_bert import BertConfig from transformers.utils import logging logger = logging.get_logger(__name__) class BertEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config def forward( self, input_ids=None, position_ids=None, query_embeds=None, past_key_values_length=0, ): if input_ids is not None: seq_length = input_ids.size()[1] else: seq_length = 0 if position_ids is None: position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length].clone() if input_ids is not None: embeddings = self.word_embeddings(input_ids) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings if query_embeds is not None: embeddings = torch.cat((query_embeds, embeddings), dim=1) else: embeddings = query_embeds embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config, is_cross_attention): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0 and not hasattr( config, "embedding_size"): raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: self.key = nn.Linear(config.encoder_width, self.all_head_size) self.value = nn.Linear(config.encoder_width, self.all_head_size) else: self.key = nn.Linear(config.hidden_size, 
self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding( 2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None if is_cross_attention: key_layer = self.transpose_for_scores( self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores( self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"): seq_length = hidden_states.size()[1] position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view( -1, 1) position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view( 1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding( distance + self.max_position_embeddings - 1) positional_embedding = positional_embedding.to( dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding) relative_position_scores_key = torch.einsum( "bhrd,lrd->bhlr", key_layer, positional_embedding) attention_scores = (attention_scores + relative_position_scores_query + relative_position_scores_key) attention_scores = attention_scores / math.sqrt( self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if is_cross_attention and self.save_attention: self.save_attention_map(attention_probs) attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + ( self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) outputs = ((context_layer, attention_probs) if output_attentions else (context_layer, )) outputs = outputs + (past_key_value, ) return outputs class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.self = BertSelfAttention(config, is_cross_attention) self.output = BertSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads, ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len( heads) self.self.all_head_size = (self.self.attention_head_size * self.self.num_attention_heads) self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() 
self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config) self.layer_num = layer_num if (self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0): self.crossattention = BertAttention( config, is_cross_attention=self.config.add_cross_attention) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate = BertIntermediate(config) self.output = BertOutput(config) self.intermediate_query = BertIntermediate(config) self.output_query = BertOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, query_length=0, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None) self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] if query_length > 0: query_attention_output = attention_output[:, :query_length, :] if self.has_cross_attention: assert ( encoder_hidden_states is not None ), "encoder_hidden_states must be given for cross-attention layers" cross_attention_outputs = self.crossattention( query_attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions=output_attentions, ) query_attention_output = cross_attention_outputs[0] outputs = ( outputs + cross_attention_outputs[1:-1] ) # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_attention_output, ) if attention_output.shape[1] > query_length: layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[:, query_length:, :], ) layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) outputs = (layer_output, ) + outputs outputs = outputs + (present_key_value, ) return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output def feed_forward_chunk_query(self, attention_output): intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList( [BertLayer(config, i) for i in range(config.num_hidden_layers)]) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, query_length=0, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = (() if output_attentions and self.config.add_cross_attention else None) 
next_decoder_cache = () if use_cache else None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[ i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warn( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, query_length) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, query_length, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1], ) if output_attentions: all_self_attentions = all_self_attentions + ( layer_outputs[1], ) all_cross_attentions = all_cross_attentions + ( layer_outputs[2], ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple(v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. 
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = BertConfig base_model_prefix = "bert" _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class BertModel(BertPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass. """ def __init__(self, config, add_pooling_layer=False): super().__init__(config) self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool, has_query: bool = False, ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (:obj:`Tuple[int]`): The shape of the input to the model. device: (:obj:`torch.device`): The device of the input to the model. Returns: :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if is_decoder: batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = (seq_ids[None, None, :].repeat( batch_size, seq_length, 1) <= seq_ids[None, :, None]) # add a prefix ones mask to the causal mask # causal and attention masks must have same type with pytorch version < 1.3 causal_mask = causal_mask.to(attention_mask.dtype) if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[ 1] - causal_mask.shape[1] if has_query: # UniLM style attention mask causal_mask = torch.cat( [ torch.zeros( (batch_size, prefix_seq_len, seq_length), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=1, ) causal_mask = torch.cat( [ torch.ones( (batch_size, causal_mask.shape[1], prefix_seq_len), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=-1, ) extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :]) else: extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})" .format(input_shape, attention_mask.shape)) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to( dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). """ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) # use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is None: assert ( query_embeds is not None ), "You have to specify query_embeds when input_ids is None" # past_key_values_length past_key_values_length = (past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0) query_length = query_embeds.shape[1] if query_embeds is not None else 0 embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, query_embeds=query_embeds, past_key_values_length=past_key_values_length, ) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: attention_mask = torch.ones( ((batch_size, seq_length + past_key_values_length)), device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if is_decoder: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_ids.shape, device, is_decoder, has_query=(query_embeds is not None), ) else: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, device, is_decoder) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ 0].size() else: ( encoder_batch_size, encoder_sequence_length, _, ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [ self.invert_attention_mask(mask) for mask in encoder_attention_mask ] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, query_length=query_length, ) sequence_output = encoder_outputs[0] pooled_output = (self.pooler(sequence_output) if self.pooler is not None else None) if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class BertLMHeadModel(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [ r"position_ids", r"predictions.decoder.bias" ] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, past_key_values=None, use_cache=True, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=True, reduction="mean", ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the 
encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). Returns: Example:: >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> config = BertConfig.from_pretrained("bert-base-cased") >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits """ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) if labels is not None: use_cache = False if past_key_values is not None: query_embeds = None outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) sequence_output = outputs[0] if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1]:, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores[:, :-1, :].contiguous() lm_loss = None if labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one shifted_prediction_scores = prediction_scores[:, : -1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) lm_loss = loss_fct( shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1), ) if reduction == "none": lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((lm_loss, ) + output) if lm_loss is not None else 
output return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) query_mask = input_ids.new_ones(query_embeds.shape[:-1]) attention_mask = torch.cat([query_mask, attention_mask], dim=-1) # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "query_embeds": query_embeds, "attention_mask": attention_mask, "past_key_values": past, "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), "is_decoder": True, } def _reorder_cache(self, past, beam_idx): reordered_past = () for layer_past in past: reordered_past += (tuple( past_state.index_select(0, beam_idx) for past_state in layer_past), ) return reordered_past class BertForMaskedLM(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [ r"position_ids", r"predictions.decoder.bias" ] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=False, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1]:, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (prediction_scores, ) + outputs[2:] return (((masked_lm_loss, ) + output) if masked_lm_loss is not None else output) return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )
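modeling_perceive_sampler.py above is essentially the BLIP-style Q-Former variant of BERT, with query embeddings and optional cross-attention layers. RALLM only uses a stripped-down two-layer instance of it (see init_qformer / init_Qformer in modeling_RALLM.py). A hedged sketch of how that instance is built and called, with random tensors standing in for real encoder states:

import torch
from modeling_perceive_sampler import BertConfig, BertLMHeadModel

def build_perceive_sampler(num_query_token, encoder_width):
    # Mirrors RALLM.init_qformer: a tiny BERT with cross-attention in every layer.
    cfg = BertConfig()
    cfg.num_hidden_layers = 2
    cfg.hidden_size = encoder_width
    cfg.encoder_width = encoder_width
    cfg.num_attention_heads = encoder_width // 64
    cfg.add_cross_attention = True
    cfg.cross_attention_freq = 1
    cfg.query_length = num_query_token
    qformer = BertLMHeadModel(config=cfg)

    # As in RALLM.init_Qformer: drop the parts that query embeddings never use.
    qformer.bert.embeddings.word_embeddings = None
    qformer.bert.embeddings.position_embeddings = None
    for layer in qformer.bert.encoder.layer:
        layer.output = None
        layer.intermediate = None
    qformer.cls = None
    return qformer

# Toy forward pass: 32 query states cross-attending over a 100-token context.
# In RALLM the query states are the top-k selected hidden states, so their
# number varies per passage; 32 here is only for illustration.
qformer = build_perceive_sampler(num_query_token=32, encoder_width=768)
queries = torch.randn(1, 32, 768)
context = torch.randn(1, 100, 768)
mask = torch.ones(1, 100, dtype=torch.long)
out = qformer.bert(
    query_embeds=queries,
    encoder_hidden_states=context,
    encoder_attention_mask=mask,
    return_dict=True,
)
print(out.last_hidden_state.shape)  # torch.Size([1, 32, 768])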
dataset_batch.py
# -*- coding: utf-8 -*-
import json
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class QADataset(Dataset):
    def __init__(self, data_path, train) -> None:
        super().__init__()
        self.data = []
        data = pd.read_csv(data_path).dropna()
        print(data.columns)
        condition = (data['answer'].str.len() <= 1000) & (data['summary'].str.len() <= 500)
        filtered_data = data[condition]

        with open("data/corpus.tsv", "r") as f_read:
            corpus = [i.split()[-1] for i in f_read.readlines()]

        retell_prompts = ["请复述这段被压缩的内容",
                          "复述这段被压缩的内容",
                          "请将被压缩的内容复述出来"]
        summary_prompts = ["请总结被压缩的信息",
                           "还原被压缩信息的主要内容",
                           "请写出被压缩信息的主要内容",
                           "请对之前压缩的信息进行概括",
                           "请提炼出之前被压缩信息的核心要点",
                           "请归纳一下之前被压缩的内容的主旨"]

        if train:
            # Retell task: randomly pack 1-10 corpus passages into one sample and ask the
            # model to reproduce them from the compressed <context> placeholders.
            for idx in range(5000):
                repeat_count = random.randint(1, 10)
                flag_context = "<context> " * repeat_count
                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{selected_articles_}'
                test_data = {"context": selected_articles,
                             "text_input": text,
                             "label": selected_articles_}
                self.data.append(test_data)

            # Summary task: compress the answer and ask the model to produce its summary.
            # user_token: <reserved_106>  assistant_token: <reserved_107>  eoa: </s>
            for idx, (answer, summary) in enumerate(zip(filtered_data["answer"], filtered_data["summary"])):
                answer = [answer[:1000]]
                flag_context = "<context> "
                prompt = random.choice(summary_prompts)
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{summary}'
                test_data = {"context": answer, "text_input": text, "label": summary}
                self.data.append(test_data)

            # (disabled) earlier retell variants that sampled from length-filtered articles
            # with the old "<|User|>/<|Bot|>" prompt format.

            print("data load , size:", len(self.data))
        else:
            # Evaluation: 3-5 passages per sample; the label is kept separately so the
            # model has to generate the retelling itself.
            for idx in range(100):
                repeat_count = random.randint(3, 5)
                flag_context = "<context> " * repeat_count
                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>'
                test_data = {"context": selected_articles,
                             "text_input": text,
                             "label": selected_articles_}
                self.data.append(test_data)

    # Helpers to filter articles by length.
    @staticmethod
    def filter_by_length150(article):
        return 180 <= len(article) <= 200

    @staticmethod
    def filter_by_length1000(article):
        return 50 <= len(article) <= 1000

    @staticmethod
    def filter_by_length1500(article):
        return 500 <= len(article) <= 1500

    def __getitem__(self, index):
        item_data = self.data[index]
        return item_data

    def __len__(self):
        return len(self.data)


if __name__ == "__main__":
    data_path = "QA_5000_summary.csv"
    dataset = QADataset(data_path, train=True)
    # print(dataset[0])

    val_params = {
        "batch_size": 2,
        "shuffle": False,
        "num_workers": 0,
    }

    def collate_fn(batch):
        """
        Merge a batch of __getitem__ results.
        :param batch: list of __getitem__ results
        :return: dict of lists, one list per field
        """
        merged_dict = {}
        for d in batch:
            for key, value in d.items():
                # Append to the existing list for this key, or start a new one.
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        return merged_dict

    val_loader = DataLoader(dataset, **val_params, collate_fn=collate_fn)
    for i in val_loader:
        print(i)
        break
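QADataset assumes two local files whose layout is never shown in this post: a CSV with at least "answer" and "summary" columns, and data/corpus.tsv, of which only the last whitespace-separated field of each line is kept as a passage. The generator below is only a sketch of that assumed layout (the file names follow the code above, the contents are invented) so the __main__ block can be smoke-tested without the real data.

# make_toy_data.py -- hypothetical helper, not part of the original project.
# Writes minimal stand-ins for the files QADataset reads.
import os
import pandas as pd

os.makedirs("data", exist_ok=True)

# CSV with the two columns the dataset filters on ("answer" <= 1000 chars, "summary" <= 500 chars).
pd.DataFrame({
    "answer": ["这是一段需要被压缩的较长原文。", "另一段用于摘要任务的原文。"],
    "summary": ["原文一的摘要。", "原文二的摘要。"],
}).to_csv("QA_5000_summary.csv", index=False)

# corpus.tsv: one passage per line; QADataset keeps only the last whitespace-separated field,
# so the passage text itself must contain no spaces (fine for Chinese text).
with open("data/corpus.tsv", "w") as f:
    for i in range(20):
        f.write(f"doc{i}\t检索库中的第{i}段示例文本\n")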
train_batch.py
# -*- coding: utf-8 -*-
import argparse
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model, TaskType
import deepspeed

from modeling_RALLM import RALLM
# from dataset_batch_en import QADataset
# from dataset_rerank import QADataset
# from dataset_rerank_en_gpt import QADataset
from dataset_rerank_en import QADataset

# Restrict CUDA visibility if needed, e.g. only use GPU 3:
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_english_longformer_rerank100k_2", type=str)
parser.add_argument("--encoder", default="longformer", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_msmarco2019/checkpoint-87500", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-3, type=float)
parser.add_argument("--weight_decay", default=0.005, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=1, type=int)
args = parser.parse_args()


def train(epoch, model, loader, optimizer, scaler, gradient_accumulation_steps, model_output_dir):
    model.train()
    time1 = time.time()
    losses = []
    train_bar = tqdm(loader, total=len(loader))
    for index, data in enumerate(train_bar):
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # print(data)
            outputs = model(model, **data)
            loss = outputs.loss
        # Backward pass: compute gradients for the current batch.
        loss.requires_grad_(True)
        losses.append(loss.item())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Save a checkpoint every 5000 steps.
        if (index + 1) % 5000 == 0:
            model_output_dir_ = os.path.join(model_output_dir, f"epoch{epoch}")
            model_save_path = os.path.join(model_output_dir_, "index_{}".format(index))
            if not os.path.exists(model_save_path):
                os.makedirs(model_save_path)
            torch.save(model.state_dict(),
                       os.path.join(model_save_path, "LLM_model_{:.6f}.pth".format(np.mean(losses))))

        train_bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, index, np.mean(losses)))


def validate(model, loader):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):
                text = data["text_input"]
                context = data["context"]
                label = data["label"]
                print(text)
                print("len context:", len(context))
                for text_, context_ in zip(text, context):
                    preds = model.generate(text=text_, context=[context_])
                    print(preds)
                    print(label)
                    predictions.append(preds)
                actuals.extend(label)
    return predictions, actuals


def main():
    epochs = args.epochs
    batch_size = args.batch_size
    lr = 1e-5
    gradient_accumulation_steps = 16
    model_output_dir = args.output
    # train_path = "qa_3w_summary.csv"
    train_path = args.train_dataset
    val_path = args.train_dataset

    device = torch.device(f"cuda:{args.local_rank}")
    model = RALLM(args)
    model = model.to(device)

    if args.use_lora:
        print("使用lora训练模型" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("使用lora训练模型" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["wte", "c_attn"],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)

    print(model)
    torch.cuda.empty_cache()  # free cached GPU memory

    if args.load_path:
        base_load_path = args.load_path
        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path, map_location=device)
        else:
            # The checkpoint may be sharded; list the shard files and merge them.
            file_list = ['pytorch_model.bin']
            state_dict = {}
            for file_name in file_list:
                # Load one shard and merge it into the full state dict.
                part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=device)
                state_dict.update(part_state_dict)
        print("state_dict:")
        print(state_dict.keys())
        model.load_state_dict(state_dict, strict=False)

    # Freeze the context encoder; only the adapter modules below are trained.
    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # (disabled) experiment that unfroze only layers 30-35 of the context encoder
    for param in model.ln_features.parameters():
        param.requires_grad = True
    for param in model.Qformer.parameters():
        param.requires_grad = True
    # (disabled) experiments that froze the LLM entirely or trained only the LoRA A/B matrices

    # Print trainable and non-trainable parameters.
    trainable_params = []
    non_trainable_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)
    print("Trainable Parameters:")
    print("\n".join(trainable_params))
    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))

    # (disabled) alternative PEFT setup that wrapped the whole model with LoRA on q_proj/v_proj
    # (disabled) half-precision conversion of the LLM

    scaler = torch.cuda.amp.GradScaler()

    def collate_fn(batch):
        """
        Merge a batch of __getitem__ results.
        :param batch: list of __getitem__ results
        :return: dict of lists, one list per field
        """
        merged_dict = {}
        for d in batch:
            for key, value in d.items():
                # Append to the existing list for this key, or start a new one.
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        return merged_dict

    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_path, train=True)
    training_loader = DataLoader(training_set, **train_params, collate_fn=collate_fn)

    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_path, train=False)
    val_loader = DataLoader(val_set, **val_params, collate_fn=collate_fn)

    # (disabled) AdamW with per-module learning rates for the encoder, Q-Former, ln_features and LLM
    optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr}])
    # (disabled) DataParallel over several GPUs

    print("Start Training...")
    for epoch in range(epochs):
        # train(epoch, model, training_loader, optimizer, scaler, gradient_accumulation_steps, model_output_dir)
        # print("Save Model To ", model_output_dir)
        # model.save_pretrained(model_output_dir)

        # Validation
        # print("Start Validation...")
        with torch.no_grad():
            predictions, actuals = validate(model, val_loader)
            # Save the validation results.
            final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
            val_data_path = os.path.join(model_output_dir, f"predictions_{epoch}.csv")
            final_df.to_csv(val_data_path)
            print("Validation Data To ", val_data_path)


if __name__ == '__main__':
    main()
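One thing worth flagging in train(): gradient_accumulation_steps is computed in main() and passed in, but the loop calls scaler.step(optimizer) on every batch, so no accumulation actually happens. The snippet below is only a sketch of how a GradScaler loop is usually wired when accumulation is wanted; model(**data).loss stands in for however RALLM's forward is really invoked (the loop above calls model(model, **data)).

# Sketch only: an AMP loop that actually honours gradient_accumulation_steps.
import torch

def train_with_accumulation(model, loader, optimizer, scaler, gradient_accumulation_steps):
    model.train()
    optimizer.zero_grad()
    for index, data in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # Assumed forward call; adapt to RALLM's real signature.
            loss = model(**data).loss / gradient_accumulation_steps   # scale per micro-batch
        scaler.scale(loss).backward()            # gradients accumulate across micro-batches
        if (index + 1) % gradient_accumulation_steps == 0:
            scaler.step(optimizer)               # one optimizer step per accumulation window
            scaler.update()
            optimizer.zero_grad()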
test_chat.py
# -*- coding: utf-8 -*-
import argparse
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
from torch import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoModel, AutoTokenizer, AutoModelForMaskedLM, BertTokenizer,
                          GPT2LMHeadModel, TextGenerationPipeline, AutoModelForCausalLM, AutoConfig)
from transformers.generation.utils import GenerationConfig
from peft import LoraConfig, get_peft_model, TaskType

from modeling_RALLM import RALLM

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=False, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_corpus_lora/checkpoint-200004", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.005, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=False, type=bool)
parser.add_argument("--use_lora_gpt2", default=True, type=bool)
args = parser.parse_args()


def chat(model):
    while True:
        context1 = input("输入context:")
        context2 = input("输入context2:")
        # (disabled) hard-coded example context describing the NVIDIA A6000's
        # double/single/half precision throughput, kept here for quick testing.
        flag_context = "<context> " * 2
        text = f'<reserved_106>请复述这段被压缩的内容{flag_context} <reserved_107>'
        data = {"context": [[context1, context2]], "text_input": text}

        model.eval()
        with torch.no_grad():
            with autocast(device_type="cuda", dtype=torch.float16):
                text = data["text_input"]
                context = data["context"]
                preds = model.generate(text=text, context=context)
                print("输出:", preds)


def main():
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)

    if args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("使用lora训练模型" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["c_attn"],
            inference_mode=False,
            r=64,
            lora_alpha=256,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)

    print(model)

    base_load_path = "output_qa3w_lora_gpt2_base_corpus"
    # The checkpoint may be sharded; list the shard files and merge them.
    file_list = ['pytorch_model.bin']
    state_dict = {}
    for file_name in file_list:
        # Load one shard and merge it into the full state dict.
        part_state_dict = torch.load(os.path.join(base_load_path, file_name),
                                     map_location=f"cuda:{args.local_rank}")
        state_dict.update(part_state_dict)
    model.load_state_dict(state_dict)
    model.LLM_model.generation_config = GenerationConfig.from_pretrained(base_load_path)

    # (disabled) loading a single .pth checkpoint saved by train_batch.py, e.g.
    # '/data2/xinyuuliu/InternLM-XComposer/output12/epoch9/index_29999/LLM_model_0.109371.pth'
    # (disabled) converting only the LLM to half precision
    model = model.half()
    # model.float()

    chat(model)


if __name__ == '__main__':
    main()
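Independently of the checkpoint, the prompt format used by chat() is easy to get wrong: there must be exactly one "<context> " placeholder per passage to be compressed, and the whole instruction sits between Baichuan2's <reserved_106> (user) and <reserved_107> (assistant) tokens. A minimal sketch of that assembly, with made-up passages:

# Sketch of the prompt layout expected by model.generate(text=..., context=...).
passages = ["第一段要压缩的文本", "第二段要压缩的文本"]      # made-up inputs
flag_context = "<context> " * len(passages)                 # one placeholder per passage
prompt = f"<reserved_106>请复述这段被压缩的内容{flag_context} <reserved_107>"
data = {"context": [passages], "text_input": prompt}        # context is a batch of passage lists
print(prompt)
# preds = model.generate(text=data["text_input"], context=data["context"])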
ds_config.json
{ "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 2, "offload_optimizer": { "device": "none", "pin_memory": true }, "allgather_partitions": true, "allgather_bucket_size": 2e8, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 2e8, "contiguous_gradients": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "steps_per_print": 10, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false }
fine-tune.py
# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
import argparse
import json
import logging
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List

import torch
from torch.optim import AdamW
from torch.utils.data import Dataset
from torchvision import transforms

import transformers
from transformers import Trainer, deepspeed
from transformers import (AutoModel, AutoTokenizer, AutoModelForMaskedLM, BertTokenizer,
                          GPT2LMHeadModel, TextGenerationPipeline, AutoModelForCausalLM, AutoConfig)
from transformers.trainer_pt_utils import LabelSmoother
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType

from modeling_RALLM import RALLM
# from dataset_batch import QADataset
from dataset_rerank_en import QADataset
# from dataset_rerank_en_gpt import QADataset
# from dataset_rerank import QADataset

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_lora2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_rerank100k/checkpoint-112356", type=str)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.01, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.local_rank)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    local_rank: int = field(default=None)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Merge a batch of dataset items.
        :param instances: list of __getitem__ results
        :return: dict of lists, one list per field
        """
        merged_dict = {}
        for d in instances:
            for key, value in d.items():
                # Append to the existing list for this key, or start a new one.
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        return merged_dict


def train():
    global model
    train_path = args.train_dataset
    # train_path = "data/news_summary_30w.csv"
    # val_path = "QA_5000_summary.csv"

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    torch.cuda.device(0)
    torch.cuda.empty_cache()  # free cached GPU memory

    # Init model and tokenizer.
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)
    # torch.cuda.device(0)
    torch.cuda.empty_cache()  # free cached GPU memory

    if args.use_lora:
        print("使用lora训练模型" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("使用lora训练模型" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["wte", "c_attn"],
            target_modules=["query", "key", "value", "query_global", "key_global", "value_global"],
            inference_mode=False,
            r=128,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)

    print(model)

    # Freeze the context encoder; only the adapter modules below are trained.
    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # (disabled) experiments that unfroze selected encoder layers plus ln_f, wte and wpe
    for param in model.ln_features.parameters():
        param.requires_grad = True

    # (disabled) experiments that froze the LLM or trained only the LoRA A/B matrices and embed_tokens
    trained = []
    untrained = []

    # Print trainable and non-trainable parameters.
    trainable_params = []
    non_trainable_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)
    print("Trainable Parameters:")
    print("\n".join(trainable_params))
    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))

    if args.load_path:
        base_load_path = args.load_path
        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path, map_location=device)
        else:
            # The checkpoint may be sharded; list the shard files and merge them.
            file_list = ['pytorch_model.bin']
            state_dict = {}
            for file_name in file_list:
                # Load one shard and merge it into the full state dict.
                part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=device)
                state_dict.update(part_state_dict)
        print("state_dict:")
        print(state_dict.keys())
        model.load_state_dict(state_dict, strict=False)

    # (disabled) AdamW with separate learning rates for the Q-Former and the rest of the model

    training_set = QADataset(train_path, train=True)
    # val_set = QADataset(val_path, train=False)
    print(training_set[0])

    # Training arguments.
    training_args = TrainingArguments(
        local_rank=args.local_rank,
        output_dir=args.output,                                        # output directory
        num_train_epochs=args.num_train_epochs,                        # number of training epochs
        per_device_train_batch_size=args.per_device_train_batch_size,  # batch size per device
        warmup_steps=500,                                              # warm-up steps
        weight_decay=0.01,                                             # weight decay
        logging_dir='./logs',                                          # log directory
        deepspeed="ds_config.json",
        gradient_accumulation_steps=1,
        save_strategy="epoch",
        learning_rate=5e-5,
        # lr_scheduler_type='linear',
        # logging_steps=100,
    )

    data_collator = DataCollatorForSupervisedDataset()

    # Start trainer.
    trainer = Trainer(
        model=model,
        tokenizer=model.LLM_tokenizer,
        train_dataset=training_set,
        # eval_dataset=val_set,
        data_collator=data_collator,
        args=training_args,
        # optimizers=(optimizer, None)  # custom optimizer
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=args.output)


if __name__ == "__main__":
    train()

# https://arxiv.org/pdf/2102.05951.pdf
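Unlike the usual HF collators, DataCollatorForSupervisedDataset does no tokenization or padding; it only merges the per-sample dicts into a dict of lists, which the Trainer then unpacks into the model's forward call. A self-contained sketch of what one collated batch looks like (the field names follow the dataset code above, the values are invented):

# Sketch of the collator's merge for a batch of two dataset items.
batch = [
    {"context": ["篇章A"], "text_input": "<reserved_106>请总结被压缩的信息<context> <reserved_107>摘要A", "label": "摘要A"},
    {"context": ["篇章B"], "text_input": "<reserved_106>请总结被压缩的信息<context> <reserved_107>摘要B", "label": "摘要B"},
]

merged = {}
for item in batch:
    for key, value in item.items():
        merged.setdefault(key, []).append(value)

print(merged["context"])     # [['篇章A'], ['篇章B']]
print(merged["text_input"])  # one full prompt string per sample
print(merged["label"])       # ['摘要A', '摘要B']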
fine-tune.sh
hostfile="" # deepspeed --include localhost:1,2,3 --hostfile=$hostfile fine-tune.py \ # --report_to "none" \ # --data_path "/data1/xinyuuliu/qa_data/professional_data/train_二阶段.json" \ # --model_name_or_path "/data1/xinyuuliu/Baichuan2-13B-Chat" \ # --output_dir "output_lora3_1_2" \ # --model_max_length 4000\ # --num_train_epochs 10 \ # --per_device_train_batch_size 4 \ # --gradient_accumulation_steps 1 \ # --save_strategy epoch \ # --learning_rate 2e-4 \ # --lr_scheduler_type constant \ # --adam_beta1 0.9 \ # --adam_beta2 0.98 \ # --adam_epsilon 1e-8 \ # --max_grad_norm 1.0 \ # --weight_decay 1e-4 \ # --warmup_ratio 0.0 \ # --logging_steps 1 \ # --gradient_checkpointing True \ # --deepspeed ds_config.json \ # --bf16 True \ # --tf32 True \ # --use_lora True \ # --load_lora_path /data1/xinyuuliu/Baichuan2-main/fine-tune/output_lora3_1/checkpoint-8260 # --use_NEFT True # --use_frozen True # export CUDA_LAUNCH_BLOCKING=1 # CUDA_VISIBLE_DEVICES=“2,3,4,5,6,7” deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 --hostfile=$hostfile fine-tune.py \ --encoder longformer \ --query_tokens 32 \ --output output_english_longformer_msmarco2019\ --num_train_epochs 20 \ --per_device_train_batch_size 1 \ # --load_path /data2/xinyuuliu/Baichuan2_qformer_bert/output_30w/checkpoint-22488 \
Thinking things through is also a form of effort. Make sound analyses and the right choices: our time and energy are limited, so spend them where they create the most value.