训练PaddleOCR文本方向分类模型
最近在做一个项目,涉及到扫描答题卡的方向判断。其中一种方法是训练一个文本方向分类模型来判断方向。此处记录一下训练的过程。
环境准备
在一处空闲空间足够大的地方克隆 PaddleOCR 仓库:https://github.com/PaddlePaddle/PaddleOCR
PaddleOCR 仓库体积较大,需下载约 700 MB 数据。
创建一个新的虚拟环境,根据 这个网页 的指导安装 PaddlePaddle 框架。
注意,不要安装 Numpy 2.0 或更新版本,因为 PaddleOCR 可能不兼容。Numpy 1.0 的最新版本是 1.26.4。
然后,进入 PaddleOCR 仓库的根目录,安装必要的依赖:
pip install -r requirements.txt pip install albumentations
训练数据准备
我的数据集包含约 2 万张答题卡扫描图像,数据集按照 9:1 的比例划分为训练集和测试(验证)集。我假设所有图像均为正向,以 0.2 的概率随机将图像旋转 180 度来生成颠倒的图像。在随机旋转的同时生成方向标签。
原始数据集包含多个目录,每个目录中存放一个科目的答题卡图像。我在同级目录下创建了 train
目录和 test
目录,用于保存实际的训练数据。
使用以下 bash 脚本将原始数据集中的图像拷贝到 train
目录和 test
目录:
#!/bin/bash # 定义源目录和目标目录,注意用空格分隔而不是用逗号分隔 source_dirs=("科目1" "科目2" "科目3" "科目4" "科目5") target_train="train" target_test="test" # 遍历所有源目录 for dir in "${source_dirs[@]}"; do # 获取目录中的所有图片文件名 files=$(find "$dir" -maxdepth 1 -type f -name "*.png") # 计算要移动到 train 和 test 目录中的文件数量 total_files=$(echo "$files" | wc -l) train_count=$((total_files * 90 / 100)) test_count=$((total_files - train_count)) # 从所有文件中随机选择 train_count 个文件移动到 train 目录 train_files=$(echo "$files" | shuf -n $train_count) for file in $train_files; do cp "$file" "$target_train" done # 从剩余文件中随机选择 test_count 个文件移动到 test 目录 test_files=$(echo "$files" | shuf -n $test_count) for file in $test_files; do cp "$file" "$target_test" done done
并使用以下 Python 代码完成图像的随即旋转和标签生成:
import os import random from PIL import Image from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm import multiprocessing def process_image(filename, directory, label_file): filepath = os.path.join(directory, filename) if filename.endswith('.png'): # 随机决定是否旋转图像 if random.random() > 0.8: # 旋转图像 img = Image.open(filepath) img_rotated = img.rotate(180) img_rotated.save(filepath) # 写入标签(180度旋转) with open(label_file, 'a') as f: f.write(f"{filepath}\t180\n") else: # 不旋转,直接写入标签(0度) with open(label_file, 'a') as f: f.write(f"{filepath}\t0\n") def rotate_and_label_images(directory, label_file): # 获取所有文件 files = [filename for filename in os.listdir(directory) if filename.endswith('.png')] # 使用多线程处理文件 with ThreadPoolExecutor() as executor: list(tqdm(executor.map(lambda x: process_image(x, directory, label_file), files), total=len(files))) # 指定目录和标签文件路径 train_dir = 'train' test_dir = 'test' train_label_file = 'train.txt' test_label_file = 'test.txt' # 处理 train 目录 rotate_and_label_images(train_dir, train_label_file) # 处理 test 目录 rotate_and_label_images(test_dir, test_label_file)
上述的大部分代码是使用通义千问生成的。
依次执行上述两个脚本后,train
目录和 test
目录中 20% 的图像将是颠倒的,同时与它们同级的目录中将会生成 train.txt
文件和 test.txt
文件,存储文件名和标签。两个文件大致遵循以下格式:
test/8D64250B-5B95-11EF-8A7E-3024A9806847.png 0 test/1F83BFEB-5B83-11EF-8A7E-3024A9806847.png 0 test/CC126431-5B92-11EF-8A7E-3024A9806847.png 0 test/87249B09-5B8B-11EF-8A7E-3024A9806847.png 180 test/G669F21C-5B8B-11EF-8A7E-3024A9806847.png 0 test/7AA0DB14-5B8C-11EF-8A7E-3024A9806847.png 180 test/EC082795-5B84-11EF-8A7E-3024A9806847.png 0 test/80B03DC5-3296-11EF-9E58-5CBAEF6F52AE.png 0 test/FEA16C22-24BC-11EF-BDD7-E86A6470B412.png 180 test/B1722A7E-5B8B-11EF-8A7E-3024A9806847.png 180 test/065A748A-5B99-11EF-8A7E-3024A9806847.png 0
需要注意,文件名和标签之间应该用
\t
分隔,而不是空格。否则训练脚本将无法识别。
切换到 PaddleOCR 仓库根目录,新建一个 train_data
目录,然后在其中创建一个名为 cls
的、链接到上述数据集所在目录的软链接。
ln -s /path/to/the/dataset cls
开始训练
我在训练时始终无法成功使 PaddlePaddle 调用 GPU 进行计算。我使用多种方法重装了若干次,并且使用 PaddlePaddle 训练了一个简单的卷积网络,可以正常调用 GPU。但一到 PaddleOCR 的环境中就不行了,表现为占用了一定的显存,但 GPU 完全没有计算,CPU满载。GitHub issue 中没有发现类似的问题。考虑到文本方向分类模型比较轻量,且在我的训练数据上可以快速收敛,因此我使用 CPU 完成了简单的训练。
回到 PaddleOCR 仓库根目录,打开 configs/cls/cls_mv3.yml
,根据需要进行修改。我进行了以下修改:
@@ -1,6 +1,6 @@ Global: - use_gpu: true - epoch_num: 100 + use_gpu: false + epoch_num: 10 log_smooth_window: 20 print_batch_step: 10 save_model_dir: ./output/cls/mv3/ @@ -61,7 +61,7 @@ Train: channel_first: False - ClsLabelEncode: # Class handling label - BaseDataAugmentation: - - RandAugment: + # - RandAugment: - ClsResizeImg: image_shape: [3, 48, 192] - KeepKeys:
上述改动基于 commit 1752c56。
正如上面那段说明所提到的,我在训练时无论如何也调用不了 GPU,于是我将 use_gpu
改为 false
(注意这个选项的值的首字母应该小写),并根据实际的收敛速度减少了 epoch 数量。此外,我没有使用此处描述的数据增强。
执行
python tools/train.py -c configs/cls/cls_mv3.yml
来启动训练。
在经过大约 4 个 epoch 后,模型收敛,精度约为 99.4%。10 个 epoch 跑完之后,模型将被保存到 ``
模型转换
为了执行推理,需要先将训练阶段的模型转换成推理模型。执行
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model=./output/cls/mv3/latest Global.save_inference_dir=./inference/cls/
即可将训练阶段的模型转换成推理模型。转换后,我们可以使用推理模型所在目录的名字来引用这个模型。
推理
直接使用 paddleocr
命令行工具似乎无法完成纯粹的文本方向分类任务,但训练时提供的文本方向分类模型推理脚本在 PaddleOCR 的仓库中,前面已经提到,这仓库很大。为了避免在推理端克隆庞大的 Git 仓库,我们可以简单修改推理脚本的代码,使之仅依赖编译好的 PaddleOCR
pypi 包。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys import cv2 import copy import numpy as np import math import time import traceback from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm import paddleocr.tools.infer.utility as utility from paddleocr.ppocr.postprocess import build_post_process from paddleocr.ppocr.utils.logging import get_logger from paddleocr.ppocr.utils.utility import get_image_file_list, check_and_read __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) os.environ["FLAGS_allocator_strategy"] = "auto_growth" logger = get_logger() class TextClassifier(object): def __init__(self, args): self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] self.cls_batch_num = args.cls_batch_num self.cls_thresh = args.cls_thresh postprocess_params = { "name": "ClsPostProcess", "label_list": args.label_list, } self.postprocess_op = build_post_process(postprocess_params) ( self.predictor, self.input_tensor, self.output_tensors, _, ) = utility.create_predictor(args, "cls", logger) self.use_onnx = args.use_onnx def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape h = img.shape[0] w = img.shape[1] ratio = w / float(h) if math.ceil(imgH * ratio) > imgW: resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) resized_image = cv2.resize(img, (resized_w, imgH)) resized_image = resized_image.astype("float32") if self.cls_image_shape[0] == 1: resized_image = resized_image / 255 resized_image = resized_image[np.newaxis, :] else: resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image -= 0.5 resized_image /= 0.5 padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im[:, :, 0:resized_w] = resized_image return padding_im def __call__(self, img_list): img_list = copy.deepcopy(img_list) img_num = len(img_list) # Calculate the aspect ratio of all text bars width_list = [] for img in img_list: width_list.append(img.shape[1] / float(img.shape[0])) # Sorting can speed up the cls process indices = np.argsort(np.array(width_list)) cls_res = [["", 0.0]] * img_num batch_num = self.cls_batch_num elapse = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] max_wh_ratio = 0 starttime = time.time() for ino in range(beg_img_no, end_img_no): h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): norm_img = self.resize_norm_img(img_list[indices[ino]]) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() if self.use_onnx: input_dict = {self.input_tensor.name: norm_img_batch} outputs = self.predictor.run(self.output_tensors, input_dict) prob_out = outputs[0] else: self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): label, score = cls_result[rno] cls_res[indices[beg_img_no + rno]] = [label, score] if "180" in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1 ) return img_list, cls_res, elapse def cls(image_dir: str, cls_model_dir: str, use_gpu=False): args = utility.parse_args() args.image_dir = image_dir args.cls_model_dir = cls_model_dir args.use_gpu = use_gpu image_file_list = get_image_file_list(args.image_dir) text_classifier = TextClassifier(args) valid_image_file_list = [] upside_img_list = [] # Process images in batches of 10 batch_size = 10 for i in range(0, len(image_file_list), batch_size): batch_files = image_file_list[i:i + batch_size] batch_imgs = [] for image_file in batch_files: img, flag, _ = check_and_read(image_file) if not flag: img = cv2.imread(image_file) if img is None: logger.info("error in loading image:{}".format(image_file)) continue valid_image_file_list.append(image_file) batch_imgs.append(img) try: batch_imgs, cls_res, predict_time = text_classifier(batch_imgs) except Exception as E: logger.info(traceback.format_exc()) logger.info(E) exit() for ino in range(len(batch_imgs)): # 如果识别结果为180度,需要记录并在日志中输出 if "180" in cls_res[ino][0] and cls_res[ino][1] > args.cls_thresh: upside_img_list.append(batch_files[ino]) logger.info( "The image is upside down: {}, score: {}".format( valid_image_file_list[i + ino], cls_res[ino][1] ) ) return upside_img_list def process_image(path): img = cv2.imread(path) img = cv2.rotate(img, cv2.ROTATE_180) cv2.imwrite(path, img) def check_and_rotate(path): cls_model_dir = "./cls" upside_img_list = cls(path, cls_model_dir, False) print(f"len of upside_img_list: {len(upside_img_list)}") with ThreadPoolExecutor() as executor: list(tqdm(executor.map(process_image, upside_img_list), total=len(upside_img_list)))
结果
- 误判实验:2263 张正向图片,出现了 1 次误判。
- 漏判实验:2263 张颠倒图片,出现了 2263-2056=207 次漏判,漏判率 9.1%
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 解答了困扰我五年的技术问题。时代确实变了!
· PPT革命!DeepSeek+Kimi=N小时工作5分钟完成?
· What?废柴, 还在本地部署DeepSeek吗?Are you kidding?
· 赶AI大潮:在VSCode中使用DeepSeek及近百种模型的极简方法
· DeepSeek企业级部署实战指南:从服务器选型到Dify私有化落地