Fine-tuning ChatGLM-6B on Ubuntu
1. ChatGLM-6B
ChatGLM-6B repository: https://github.com/THUDM/ChatGLM-6B
ChatGLM-6B P-Tuning repository: https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning
2. Runtime environment
A cloud instance with two RTX 3080 GPUs rented on the AutoDL platform.
Ubuntu 22.04
conda
3. Set up the environment
-- Install CUDA toolkit and cuDNN
conda install cudatoolkit
conda install cudnn
-- Create a virtual environment
conda create -n tuning-chatglm python=3.10
conda activate tuning-chatglm
-- Pull the code
git clone https://github.com/THUDM/ChatGLM-6B.git
-- Install dependencies (requirements.txt lives in the repository root)
cd ChatGLM-6B
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
cd ptuning
# Install the additional dependencies; the ptuning README calls for these
pip install rouge_chinese nltk jieba datasets -i https://pypi.tuna.tsinghua.edu.cn/simple/
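Before moving on, it is worth confirming that PyTorch can actually see the GPUs. A minimal sanity check, assuming the requirements installed a CUDA-enabled torch build:

import torch

# If this prints False or 0, revisit the CUDA/cuDNN installation above
# before attempting any training.
print(torch.__version__)
print(torch.cuda.is_available())   # expected: True
print(torch.cuda.device_count())   # expected: 2 on this RTX 3080 x2 instance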
4. Prepare the data
-- Dataset format (ADGEN samples: "content" holds clothing attribute keywords, "summary" holds the target ad copy)
{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
-- Download the dataset
wget -O AdvertiseGen.tar.gz https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1
tar -xzvf AdvertiseGen.tar.gz
Place the extracted AdvertiseGen directory under the ptuning directory.
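Before launching training, it can help to eyeball a few samples. A minimal sketch, assuming the usual AdvertiseGen layout of one JSON object per line:

import json

# Print the first three training samples
with open("AdvertiseGen/train.json", encoding="utf-8") as f:
    for i, line in enumerate(f):
        sample = json.loads(line)
        print("content:", sample["content"])   # maps to prompt_column below
        print("summary:", sample["summary"])   # maps to response_column below
        if i >= 2:
            break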
5. Run the training script
-- Go into the ptuning directory and first edit the train.sh script, mainly its train_file, validation_file, model_name_or_path, and output_dir parameters:
train_file: path to the training data file
validation_file: path to the validation data file
model_name_or_path: path to the original ChatGLM-6B model files
output_dir: path for the fine-tuned output (checkpoints)
bash train.sh
-- The content of the script is below; modify the corresponding values:
PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python main.py \
--do_train \
--train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /root/autodl-tmp/THUDM/chatglm-6b-int4 \
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 4 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
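For reference: per_device_train_batch_size 4 combined with gradient_accumulation_steps 4 gives an effective batch size of 4 × 4 = 16, and quantization_bit 4 loads the frozen base model in INT4, which is what lets the run fit on a single card (CUDA_VISIBLE_DEVICES=0 uses only one of the two GPUs). Only the small prefix encoder introduced by pre_seq_len is trained; the 6B base weights stay frozen.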
Wait for training to finish...
6. Before-and-after comparison
After fine-tuning
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

MODEL_PATH = "/root/autodl-tmp/THUDM/chatglm-6b-int4"
CHECKPOINT_PATH = "/root/autodl-tmp/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000"

# Load the tokenizer; pre_seq_len must match the value used in training
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(
    MODEL_PATH, config=config, trust_remote_code=True
).cuda()

# Load only the prefix-encoder weights from the P-Tuning checkpoint;
# the base model weights are left untouched
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

print("Quantized to 4 bit")
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()  # keep the prefix encoder in fp32
model = model.eval()

print("User: 你好\n")
response, history = model.chat(tokenizer, "你好", history=[])
print("ChatGLM-6B:\n", response)
print("\n------------------------------------------------\nUser: ")
line = input()
while line:
    response, history = model.chat(tokenizer, line, history=history)
    print("ChatGLM-6B:\n", response)
    print("\n------------------------------------------------\nUser: ")
    line = input()
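Note that only the prefix encoder's weights are loaded from checkpoint-3000; everything else still comes from the original chatglm-6b-int4 files. This is the point of P-Tuning v2: the checkpoint stores just the small set of soft-prompt parameters, so it is orders of magnitude smaller than the full 6B model.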
Before fine-tuning
import torch
from transformers import AutoTokenizer, AutoModel

# Load the original INT4 model with no P-Tuning checkpoint applied
tokenizer = AutoTokenizer.from_pretrained(
    "/root/autodl-tmp/THUDM/chatglm-6b-int4", trust_remote_code=True
)
model = (
    AutoModel.from_pretrained(
        "/root/autodl-tmp/THUDM/chatglm-6b-int4", trust_remote_code=True
    )
    .half()
    .cuda()
)
model = model.eval()

while True:
    a = input("Enter your question (type q to quit): ")
    if a.strip() == "q":
        exit()
    # Wrap the question in the same 问题/答案 (question/answer) prompt template
    response, history = model.chat(
        tokenizer, "问题:" + a.strip() + "\n答案:", max_length=256, history=[]
    )
    print("Answer:", response)
Run the two scripts in separate terminal windows and compare the answers; the fine-tuned model should follow the ADGEN keyword-to-ad-copy style much more closely than the original one.