ChatGLM-6B-PT微调

开发环境

ChatGLM2-6B源码

git clone https://github.com/THUDM/ChatGLM2-6B.git

下载模型

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm2-6b-int4

安装依赖

cd ChatGLM2-6B
# 新建模型目录,将模型复制到model目录下
mkdir model
# 安装依赖
pip install -r requirements.txt -i https://mirror.sjtu.edu.cn/pypi/web/simple
# 运行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖
pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/

下载ADGEN数据集

微调前

image

cd ptuning
# 复制训练数据集到ptuning目录中
cp -r /mnt/AdvertiseGen .

# 微调训练
# 训练集目录 ptuning/AdvertiseGen/
# 模型目录 ChatGLM2-6B/model/
# 模型训练输出目录 ptuning/output/
# max_steps 最大训练步数
# save_steps 保存步骤数
# logging_steps 记录日志的频率
# quantization_bit 控制量化的精度
# pre_seq_len 预先设定的序列长度
# learning_rate 使用的学习率
# gradient_accumulation_steps 连续计算梯度的步数
torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/ChatGLM2-6B/model/chatglm2-6b-int4 \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 2e-2 \
    --pre_seq_len 128 \
    --quantization_bit 4

修改训练步数

torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/ChatGLM2-6B/model/chatglm2-6b-int4 \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 100 \
    --logging_steps 10 \
    --save_steps 50 \
    --learning_rate 2e-2 \
    --pre_seq_len 128 \
    --quantization_bit 4

微调后

# 微调后
import os
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel

model_path = "model/chatglm2-6b-int4"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# 微调后代码
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join("/home/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt--/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

# 使用 Markdown 格式打印模型输出
# from IPython.display import display, Markdown, clear_output

# def display_answer(model, query, history=[]):
#     for response, history in model.stream_chat(
#             tokenizer, query, history=history):
#         clear_output(wait=True)
#         display(Markdown(response))
#     return history

# display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")

prompt = "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞"
# 模型输出
current_length = 0
for response, history in model.stream_chat(tokenizer, prompt, history=[]):
    print(response[current_length:], end="", flush=True)
    current_length = len(response)
print("")

image

本文作者:逢生博客

本文链接:https://www.cnblogs.com/wufengsheng/p/17735947.html

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   逢生博客  阅读(308)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起
  1. 1 晚安 顏人中
  2. 2 出山 花粥 / 王胜娚
  3. 3 我们俩 郭顶
  4. 4 日落大道 梁博
  5. 5 像鱼 王贰浪
  6. 6 把回忆拼好给你 王贰浪
  7. 7 时光背面的我 刘至佳 / 韩瞳
  8. 8 愿你余生漫长 王贰浪
  9. 9 追寻你 王天戈 / 川青
  10. 10 夜空中最亮的星 逃跑计划
  11. 11 孤勇者 陈奕迅
  12. 12 不为谁而作的歌 林俊杰
  13. 13 消愁 毛不易
  14. 14 这一生关于你的风景 隔壁老樊
愿你余生漫长 - 王贰浪
00:00 / 00:00
An audio error has occurred, player will skip forward in 2 seconds.