在ubantu系统中微调ChatGLM-6B

1. ChatGLM-6B

ChatGLM-6B仓库地址:https://github.com/THUDM/ChatGLM-6B
ChatGLM-6B/P-Tuning仓库地址:https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

2、运行环境

在autodl平台购买的RTX 3080x2云主机。
ubantu22.02
conda

3、安装环境

--安装cuda等环境
conda install cudatoolkit
conda install cudnn

--创建虚拟环境
conda create -n tuning-chatglm python=3.10
conda activate tuning-chatglm

-- 拉取代码
git clone https://github.com/THUDM/ChatGLM-6B.git

-- 安装依赖库
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/

cd ptuning
# 再次安装依赖,ptuning文档里有说明
pip install rouge_chinese nltk jieba datasets  -i https://pypi.tuna.tsinghua.edu.cn/simple/

4、数据准备

-- 数据集格式
{  
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",  
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"  
}

-- 下载数据集
wget -O AdvertiseGen.tar.gz https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1

将解压后的数据集目录放在ptuning目录下

5、执行训练脚本

--进入到ptuning目录,首先,修改train.sh脚本,主要是修改其中的train_file、validation_file、model_name_or_path、output_dir参数:

train_file:训练数据文件位置
validation_file:验证数据文件位置
model_name_or_path:原始ChatGLM-6B模型文件路径
output_dir:输出模型文件路径

bash ds_train_finetune.sh

--下面是脚本中都内容,需要对应修改
PRE_SEQ_LEN=128 && LR=2e-2 && CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /root/autodl-tmp/THUDM/chatglm-6b-int4 \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

等待...

6、对比测试

微调后

import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer


MODEL_PATH = "/root/autodl-tmp/THUDM/chatglm-6b-int4"
CHECKPOINT_PATH = "/root/autodl-tmp/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000"

# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(
    MODEL_PATH, config=config, trust_remote_code=True
).cuda()

prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}

for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder.") :]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

print(f"Quantized to 4 bit")
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()


print("用户:你好\n")
response, history = model.chat(tokenizer, "你好", history=[])
print("ChatGLM-6B:\n", response)
print("\n------------------------------------------------\n用户:")

line = input()
while line:
    response, history = model.chat(tokenizer, line, history=history)
    print("ChatGLM-6B:\n", response)
    print("\n------------------------------------------------\n用户:")
    line = input()

微调前

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained(
    "/root/autodl-tmp/THUDM/chatglm-6b-int4", trust_remote_code=True
)
model = (
    AutoModel.from_pretrained(
        "/root/autodl-tmp/THUDM/chatglm-6b-int4", trust_remote_code=True
    )
    .half()
    .cuda()
)
model = model.eval()

while True:
    a = input("请输入您的问题:(输入q以退出)")
    if a.strip() == "q":
        exit()
    response, history = model.chat(
        tokenizer, "问题:" + a.strip() + "\n答案:", max_length=256, history=[]
    )
    print("回答:", response)

分别打开窗口执行。

posted @ 2024-07-17 10:31  雨梦山人  阅读(63)  评论(0编辑  收藏  举报