PaddleNlp文本分类操作步骤

一、数据预处理(缺失、重复等,步骤略)

二、标注平台搭建

(1)搭建标注环境(建议使用虚拟环境)

  • python使用3.9.12版本,其他依赖如下(建议完全按照文中python版本及依赖版本安装,否则可能会出现依赖冲突或不兼容问题):
appdirs                        1.4.4
asgiref                        3.7.2
attr                           0.3.1
attrs                          23.1.0
azure-core                     1.29.4
azure-storage-blob             12.18.3
bleach                         5.0.1
boto                           2.49.0
boto3                          1.16.28
botocore                       1.19.28
boxing                         0.1.4
cachetools                     5.3.1
certifi                        2023.7.22
cffi                           1.16.0
charset-normalizer             3.3.0
click                          8.1.7
colorama                       0.4.6
cryptography                   41.0.4
defusedxml                     0.7.1
Django                         3.2.20
django-annoying                0.10.6
django-cors-headers            3.6.0
django-debug-toolbar           3.2.1
django-environ                 0.10.0
django-extensions              3.1.0
django-filter                  2.4.0
django-model-utils             4.1.1
django-ranged-fileresponse     0.1.2
django-rq                      2.5.1
django-storages                1.12.3
django-user-agents             0.4.0
djangorestframework            3.13.1
drf-dynamic-fields             0.3.0
drf-flex-fields                0.9.5
drf-generators                 0.3.0
expiringdict                   1.2.2
google-api-core                2.11.0
google-auth                    2.14.1
google-cloud-appengine-logging 1.1.0
google-cloud-audit-log         0.2.0
google-cloud-core              2.3.2
google-cloud-logging           2.7.2
google-cloud-storage           2.5.0
google-crc32c                  1.5.0
google-resumable-media         2.3.3
googleapis-common-protos       1.56.4
grpc-google-iam-v1             0.12.4
grpcio                         1.59.0
grpcio-status                  1.48.2
htmlmin                        0.1.12
humansignal-drf-yasg           1.21.9
idna                           3.4
ijson                          3.2.3
inflection                     0.5.1
isodate                        0.6.1
jmespath                       0.10.0
joblib                         1.3.2
jsonschema                     3.2.0
label-studio                   1.9.1.post0
label-studio-converter         0.0.57
label-studio-tools             0.0.3
launchdarkly-server-sdk        7.5.0
lockfile                       0.12.2
lxml                           4.9.3
nltk                           3.6.7
numpy                          1.24.3
ordered-set                    4.0.2
packaging                      23.2
pandas                         2.1.1
Pillow                         10.1.0
pip                            23.3
proto-plus                     1.22.3
protobuf                       3.20.3
psycopg2-binary                2.9.6
pyasn1                         0.5.0
pyasn1-modules                 0.3.0
pycparser                      2.21
pydantic                       1.10.13
pyRFC3339                      1.1
pyrsistent                     0.19.3
python-dateutil                2.8.2
python-json-logger             2.0.4
pytz                           2022.7.1
PyYAML                         6.0.1
redis                          3.5.3
regex                          2023.10.3
requests                       2.31.0
rq                             1.10.1
rsa                            4.9
rules                          2.2
s3transfer                     0.3.7
semver                         2.13.0
sentry-sdk                     1.32.0
setuptools                     68.2.2
six                            1.16.0
sqlparse                       0.4.4
tqdm                           4.66.1
typing_extensions              4.8.0
tzdata                         2023.3
ua-parser                      0.18.0
ujson                          5.8.0
uritemplate                    4.1.1
urllib3                        1.26.16
user-agents                    2.2.0
webencodings                   0.5.1
wheel                          0.40.0
xmljson                        0.2.0
  • 启动label_studio
label-studio start

(2)进入浏览器页面,注册登录,开始标注

  • 准备数据,将需要标注的文本数据集按每行一条的格式整理成txt文件,注意不要有多余的空格、回车等特殊符号。
  • 点击Create按钮创建一个新的标注项目,填写项目名称。

  • Data Import页面导入标注数据,并选择List of tasks

  • Labeling Setup中选择Text Classification,表示文本分类任务标注

  • 在设置标签页面,先删除默认标签,再添加需要使用的新标签,添加完成后点击Save保存

  • 进入项目,开始标注,点击待标注数据,选择标签,点击Submit保存标注结果

  • 标注完成后,再次点击数据,标注结果显示位置报JS错误,无法显示,但不影响标注工作,可以点击</>符号查询标注结果

  • 所有数据标注完成后,勾选(全选)已标注文本ID,选择导出的文件类型为JSON,导出数据,重命名为label_studio.json,待使用

三、搭建PaddleNLP环境,转换标注数据

python label_studio.py \
    --label_studio_file ./data/label_studio.json \
    --save_dir ./data \
    --splits 0.8 0.1 0.1 \
    --options ./data/label.txt
label_studio_file: 从label studio导出的数据标注文件。
save_dir: 训练数据的保存目录,默认存储在data目录下。
splits: 划分数据集时训练集、验证集所占的比例。默认为[0.8, 0.1, 0.1]表示按照8:1:1的比例将数据划分为训练集、验证集和测试集。
options: 指定分类任务的类别标签。若输入类型为文件,则文件中每行一个标签。
is_shuffle: 是否对数据集进行随机打散,默认为True。
seed: 随机种子,默认为1000.

四、训练

python run_train.py \
--seed 1000 \
--output_dir ./checkpoint/model_best \
--dataset_path ./data/vio_type/ \
--logging_steps 10  \
--save_steps 100  \
--eval_steps 100 \
--max_seq_length 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--num_train_epochs 20 \
--learning_rate 1e-5 \
--do_train True \
--do_eval True \
--do_export True \
--export_model_dir ./checkpoint/model_best \
--disable_tqdm True \
--metric_for_best_model macro_f1 \
--load_best_model_at_end True \
--save_total_limit 1 \
--save_plm
single_label: 每条样本是否只预测一个标签。默认为False,表示多标签分类。
device: 训练设备,可选择 'cpu''gpu' 其中的一种;默认为 GPU 训练。
logging_steps: 训练过程中日志打印的间隔 steps 数,默认10save_steps: 训练过程中保存模型 checkpoint 的间隔 steps 数,默认100eval_steps: 训练过程中保存模型 checkpoint 的间隔 steps 数,默认100seed:全局随机种子,默认为 42model_name_or_path:进行 few shot 训练使用的预训练模型。默认为 "utc-base", 可选"utc-xbase", "utc-base", "utc-medium", "utc-mini", "utc-micro", "utc-nano", "utc-pico"output_dir:必须,模型训练或压缩后保存的模型目录;默认为 None 。
dataset_path:数据集文件所在目录;默认为 ./data/ 。
train_file:训练集后缀;默认为 train.txt 。
dev_file:开发集后缀;默认为 dev.txt 。
max_seq_len:文本最大切分长度,包括标签的输入超过最大长度时会对输入文本进行自动切分,标签部分不可切分,默认为512per_device_train_batch_size:用于训练的每个 GPU 核心/CPU 的batch大小,默认为8per_device_eval_batch_size:用于评估的每个 GPU 核心/CPU 的batch大小,默认为8num_train_epochs: 训练轮次,使用早停法时可以选择 100;默认为10learning_rate:训练最大学习率,UTC 推荐设置为 1e-5;默认值为3e-5do_train:是否进行微调训练,设置该参数表示进行微调训练,默认不设置。
do_eval:是否进行评估,设置该参数表示进行评估,默认不设置。
do_export:是否进行导出,设置该参数表示进行静态图导出,默认不设置。
export_model_dir:静态图导出地址,默认为None。
overwrite_output_dir: 如果 True,覆盖输出目录的内容。如果 output_dir 指向检查点目录,则使用它继续训练。
disable_tqdm: 是否使用tqdm进度条。
metric_for_best_model:最优模型指标, UTC 推荐设置为 macro_f1,默认为None。
load_best_model_at_end:训练结束后是否加载最优模型,通常与metric_for_best_model配合使用,默认为False。
save_total_limit:如果设置次参数,将限制checkpoint的总数。删除旧的checkpoints 输出目录,默认为None。

五、测试

# -*- coding:utf-8  -*-

from pprint import pprint
from paddlenlp import Taskflow
from conf.db_config import pool as dbpool

schema = ['分类1', '分类2', '分类3']
my_cls = Taskflow("zero_shot_text_classification", model="utc-base", schema=schema, task_path='./checkpoint/model_best/plm', precision="fp8")
pprint(my_cls('中国银河证券澄清合并重组传闻'))
[{
	'predictions': [{
		'label': '分类1',
		'score': 0.9991679502941138
	}],
	'text_a': '中国银河证券澄清合并重组传闻'
}]
posted @   金元宝骑骑  阅读(481)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
点击右上角即可分享
微信分享提示