拆分PPOCRLabel标注的数据集并生成识别数据集

拆分PPOCRLabel标注的数据集并生成识别数据集

说明

关于PPOCRLabel以及本文缘起

PPOCRLabel是OCR领域的标注工具,其本身自带导出识别数据和拆分数据集的功能。其中:

PPOCRLabel本身自带导出识别数据的功能,但保存裁剪出的检测框图片时,会把高宽比不小于1.5的图片自动旋转90度(旋转逻辑位于其调用的get_rotate_crop_image函数中),具体见其saveRecResult函数的实现代码: https://github.com/PFCCLab/PPOCRLabel/blob/81a9c550b7b625bd003a16681fcc7d782184d1f4/PPOCRLabel.py#L3371

    def saveRecResult(self):
        """Export a text-recognition dataset from the current PPOCRLabel annotations.

        For every image key in ``self.fileStatedict``, each box whose
        ``"difficult"`` flag is false is cropped with ``get_rotate_crop_image``
        (which rotates tall crops by 90 degrees) and saved as
        ``crop_img/<name>_crop_<i>.jpg`` in the directory containing
        ``self.PPlabelpath``. One ``crop_img/<file><TAB><transcription>`` line per
        crop is written to ``rec_gt.txt`` in that same directory.

        Images whose processing raises an exception other than KeyError are
        collected and listed in a message box. A final message box naming the
        crop folder is always shown, even when some images failed.
        """
        # Abort early if any of these attributes is still an empty dict.
        if {} in [self.PPlabelpath, self.PPlabel, self.fileStatedict]:
            QMessageBox.information(self, "Information", "Check the image first")
            return

        # All outputs go next to the label file.
        base_dir = os.path.dirname(self.PPlabelpath)
        rec_gt_dir = base_dir + "/rec_gt.txt"
        crop_img_dir = base_dir + "/crop_img/"
        ques_img = []  # keys of images that failed to export
        # os.mkdir (not makedirs): base_dir itself must already exist.
        if not os.path.exists(crop_img_dir):
            os.mkdir(crop_img_dir)

        with open(rec_gt_dir, "w", encoding="utf-8") as f:
            for key in self.fileStatedict:
                idx = self.getImglabelidx(key)
                try:
                    # Image keys are resolved against the parent of base_dir.
                    img_path = os.path.dirname(base_dir) + "/" + key
                    # Read raw bytes and decode; flag -1 keeps the image unchanged
                    # (NOTE(review): presumably chosen to cope with non-ASCII paths — confirm).
                    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
                    for i, label in enumerate(self.PPlabel[idx]):
                        # Skip "difficult" boxes; i still advances, so crop
                        # indices in the file names can have gaps.
                        if label["difficult"]:
                            continue
                        # This helper may rotate the crop (see get_rotate_crop_image).
                        img_crop = get_rotate_crop_image(
                            img, np.array(label["points"], np.float32)
                        )
                        # NOTE(review): the name is derived from idx (result of
                        # getImglabelidx), not from key — confirm idx is path-like.
                        img_name = (
                            os.path.splitext(os.path.basename(idx))[0]
                            + "_crop_"
                            + str(i)
                            + ".jpg"
                        )
                        # Encode to JPEG in memory and write the bytes with tofile.
                        cv2.imencode(".jpg", img_crop)[1].tofile(
                            crop_img_dir + img_name
                        )
                        f.write("crop_img/" + img_name + "\t")
                        f.write(label["transcription"] + "\n")
                except KeyError as e:
                    # Silently ignored, e.g. from the self.PPlabel[idx] or
                    # label[...] lookups above.
                    pass
                except Exception as e:
                    # Any other failure: remember the image and print the traceback.
                    ques_img.append(key)
                    traceback.print_exc()
        if ques_img:
            QMessageBox.information(
                self,
                "Information",
                "The following images can not be saved, please check the image path and labels.\n"
                + "".join(str(i) + "\n" for i in ques_img),
            )
        QMessageBox.information(
            self,
            "Information",
            "Cropped images have been saved in " + str(crop_img_dir),
        )

其中get_rotate_crop_image函数定义: https://github.com/PFCCLab/PPOCRLabel/blob/81a9c550b7b625bd003a16681fcc7d782184d1f4/libs/utils.py#L137

def get_rotate_crop_image(img, points):
    """Crop the quadrilateral ``points`` out of ``img`` with a perspective warp.

    First the corners are put into a consistent winding order: corners 1 and 3
    are swapped when the signed-area sum ``d`` is negative. The quadrilateral is
    then mapped onto an upright rectangle. Its width is the longer of edges 0-1
    and 2-3, and its height is the longer of edges 0-3 and 1-2. If the warped
    crop is at least 1.5 times taller than it is wide, it is rotated by 90
    degrees with ``np.rot90``.

    Args:
        img: Source image array.
        points: Four (x, y) corners. Presumably a (4, 2) float32 NumPy array,
            as ``cv2.getPerspectiveTransform`` expects; confirm with callers.
            NOTE: modified in place when the corner order is swapped.

    Returns:
        The cropped (and possibly rotated) image. If an exception occurs inside
        the ``try`` block, it is printed and the function returns ``None``.
        Exceptions raised by the orientation check before the ``try`` propagate.
    """
    # Use Green's theorem to judge clockwise or counterclockwise
    # author: biyanhua
    # Trapezoid-rule signed-area sum over the closed loop p3->p0->p1->p2->p3
    # (index runs -1..2, so exactly four points are assumed).
    d = 0.0
    for index in range(-1, 3):
        d += (
            -0.5
            * (points[index + 1][1] + points[index][1])
            * (points[index + 1][0] - points[index][0])
        )
    if d < 0:  # counterclockwise
        # Swap corners 1 and 3 to reverse the winding; this writes into the
        # caller's array.
        tmp = np.array(points)
        points[1], points[3] = tmp[3], tmp[1]

    try:
        # Target size: the longer of each pair of opposite edges.
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3]),
            )
        )
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2]),
            )
        )
        # Destination corners in the same order as points: TL, TR, BR, BL.
        pts_std = np.float32(
            [
                [0, 0],
                [img_crop_width, 0],
                [img_crop_width, img_crop_height],
                [0, img_crop_height],
            ]
        )
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M,
            (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC,
        )
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        # Tall crops (height/width >= 1.5) are rotated by 90 degrees; this is
        # the rotation the surrounding article wants to avoid.
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img
    except Exception as e:
        # Errors are only printed; the function then implicitly returns None.
        print(e)

但是,有些场景并不需要把裁剪出的检测框图片旋转后再保存。

另外,PPOCRLabel官方还提供了可用于拆分数据集的脚本gen_ocr_train_val_test.py:

python gen_ocr_train_val_test.py --trainValTestRatio 9:1:0 --datasetRootPath dataset/handwritten_digits/images --detRootPath ./train_data/det --recRootPath ./train_data/rec

拆分数据集并生成识别数据集

标注文件格式

假设我们有数据集及其标注文件:

# Root directory that the image paths inside the label file are relative to.
data_dir = "data/"
# PPOCRLabel detection label file (PaddleOCR text-detection format).
label_file = "data/Label_det.txt"

PPOCRLabel的标注文件是 PaddleOCR 文字检测数据格式。

PaddleOCR 中的文本检测算法支持的标注文件格式如下,中间用"\t"分隔:

" 图像文件名                    json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]

这里假设将data_dir与标注文件中的图像文件名拼接,即可得到图片的路径。

读取标注文件

读取标注文件内容,并检查标注文件中引用的图片是否都存在:

def get_label_lines(data_dir: str, label_file: str):
    """Read a PaddleOCR detection-format label file and report missing images.

    Each non-blank line has the form ``<image path><TAB><json annotations>``.
    The image path is resolved relative to ``data_dir``. Missing images are
    only reported (printed); their lines are still returned.

    Args:
        data_dir: Directory that the image paths in the label file are relative to.
        label_file: Path of the label file.

    Returns:
        The non-blank lines of the label file, unchanged (trailing newline kept).

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    # Transcriptions are often non-ASCII (e.g. Chinese), so don't depend on the
    # platform's default encoding. Blank lines (e.g. a trailing empty line)
    # carry no annotation and would break the tab split below.
    with open(label_file, "r", encoding="utf-8") as f:
        label_lines = [line for line in f if line.strip()]

    for line in label_lines:
        # Split only on the first tab: everything after it is the JSON payload.
        img_path, sep, _ = line.partition("\t")
        if not sep:
            raise ValueError(f"malformed label line (no tab): {line!r}")
        img_full_path = os.path.join(data_dir, img_path)
        if not os.path.exists(img_full_path):
            print(f'{img_full_path} not exists!')

    return label_lines

# Image paths in the label file are relative to data_dir, which is defined
# earlier (the previous code referenced an undefined name, parent_dir).
label_lines = get_label_lines(data_dir, label_file)

拆分标注数据并保存

拆分label_lines:

from sklearn.model_selection import train_test_split

# Hold out 20% of the labelled images for testing; a fixed seed keeps the split reproducible.
train_set_label_lines, test_set_label_lines = train_test_split(
    label_lines,
    test_size=0.2,
    random_state=42,
)

保存为具体的数据集(图片和标注文件):

def save_split_data(
    split_label_lines,
    data_dir,
    dest_dir="dataset",
    split_name="train",
):
    """Copy the images of one split into their own folder and write its label file.

    The output folder is ``<dest_dir>/<image dir name>_<split_name>``.
    ``<image dir name>`` is the name of the directory holding the images
    referenced by the label lines; all images of a split must share that
    directory. Images stored directly in ``data_dir`` use the name of
    ``dest_dir`` instead. The label file ``Label_<image dir name>_<split_name>.txt``
    is written inside the output folder, with image paths rewritten relative
    to ``dest_dir``.

    Args:
        split_label_lines: Detection-format label lines
            (``<image path><TAB><json>``), with image paths relative to
            ``data_dir``. Blank lines are ignored.
        data_dir: Directory that the source image paths are relative to.
        dest_dir: Root directory of the generated dataset.
        split_name: Suffix identifying the split, e.g. ``"train"`` or ``"test"``.

    Returns:
        A tuple ``(split image directory, label file path)``.

    Raises:
        ValueError: If there are no label lines, a line has no tab separator,
            or the images come from more than one directory.
    """
    def _image_dir_name(img_path):
        # Name of the directory holding the image. Images at the top level of
        # data_dir fall back to the name of dest_dir.
        return os.path.basename(os.path.dirname(img_path)) or os.path.basename(
            os.path.normpath(dest_dir)
        )

    entries = []
    for line in split_label_lines:
        if not line.strip():
            continue
        # Split on the first tab only: the remainder is the JSON annotation.
        img_path, label_text = line.split("\t", 1)
        entries.append((img_path, label_text.rstrip("\r\n")))
    if not entries:
        raise ValueError(f"no label lines for split '{split_name}'")

    parent_dir_name = _image_dir_name(entries[0][0])
    # Validate everything before copying, so a bad split does not leave a
    # half-populated output folder behind.
    for img_path, _ in entries:
        if _image_dir_name(img_path) != parent_dir_name:
            raise ValueError(
                f"{img_path} is not in the '{parent_dir_name}' directory; "
                "all images of a split must share one directory"
            )

    rel_dest_img_path = "_".join([parent_dir_name, split_name])
    split_dir = os.path.join(dest_dir, rel_dest_img_path)
    os.makedirs(split_dir, exist_ok=True)

    new_label_lines = []
    for img_path, label_text in entries:
        img_name = os.path.basename(img_path)
        shutil.copy2(os.path.join(data_dir, img_path), os.path.join(split_dir, img_name))
        # "/" instead of os.path.join keeps the label file portable across OSes.
        new_label_lines.append(f"{rel_dest_img_path}/{img_name}\t{label_text}")

    label_file_path = os.path.join(
        split_dir, "_".join(["Label", parent_dir_name, split_name]) + ".txt"
    )
    # Transcriptions may be non-ASCII; write UTF-8 explicitly.
    with open(label_file_path, "w", encoding="utf-8") as f:
        f.write("\n".join(new_label_lines))

    return split_dir, label_file_path
# Materialise each split as its own image folder plus detection label file.
train_img_dir, train_det_label_file = save_split_data(
    train_set_label_lines, data_dir=data_dir, dest_dir="dataset", split_name="train"
)

test_img_dir, test_det_label_file = save_split_data(
    test_set_label_lines, data_dir=data_dir, dest_dir="dataset", split_name="test"
)

生成识别图片和标签

def _crop_bounding_box(img, points):
    """Return the axis-aligned bounding-box crop of ``points``, clamped to ``img``."""
    height, width = img.shape[:2]
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    # Use min/max over all points so any point order (or polygon) works, and
    # clamp so negative coordinates don't wrap around via negative indexing.
    x0, x1 = max(int(min(xs)), 0), min(int(max(xs)), width)
    y0, y1 = max(int(min(ys)), 0), min(int(max(ys)), height)
    return img[y0:y1, x0:x1]


def generate_rec_img_label(label_file_path, parent_dir, do_crop=True):
    """Build a text-recognition dataset from a detection-format label file.

    For every annotated box, an image ``<stem>_crop_<idx>.jpg`` is written to
    ``<parent_dir>/<image dir>/crop_img/``, and a line
    ``<image dir>/crop_img/<name><TAB><transcription>`` is recorded. The
    recognition label file ``<label file stem>_rec.txt`` is written to
    ``parent_dir``.

    Crops are never rotated. Each crop is the axis-aligned bounding box of the
    box's points, clamped to the image bounds. Boxes whose crop is empty are
    skipped with a printed message.

    Args:
        label_file_path: Detection label file whose image paths are relative
            to ``parent_dir``. Blank lines are ignored.
        parent_dir: Dataset root directory.
        do_crop: If False, the whole source image is copied for every box
            instead of being cropped, and the image is never decoded.

    Returns:
        Path of the written recognition label file.

    Raises:
        FileNotFoundError: If a source image cannot be read while cropping.
        OSError: If a cropped image cannot be written.
    """
    # Transcriptions may be non-ASCII; use UTF-8 for reading and writing.
    with open(label_file_path, "r", encoding="utf-8") as f:
        label_lines = [line for line in f if line.strip()]

    rec_label_lines = []

    for line in label_lines:
        img_path, label_text = line.split("\t", 1)
        label_list = json.loads(label_text)
        src_img_path = os.path.join(parent_dir, img_path)
        parent_img_dir = os.path.basename(os.path.dirname(img_path))
        dest_img_dir = os.path.join(parent_dir, parent_img_dir, "crop_img")
        os.makedirs(dest_img_dir, exist_ok=True)
        # Relative crop dir for the label file; "/" keeps it portable.
        rel_crop_dir = f"{parent_img_dir}/crop_img" if parent_img_dir else "crop_img"
        img_stem = os.path.splitext(os.path.basename(img_path))[0]

        img = None
        if do_crop:
            # cv2.imread returns None instead of raising when it cannot read.
            img = cv2.imread(src_img_path)
            if img is None:
                raise FileNotFoundError(f"cannot read image: {src_img_path}")

        for idx, label in enumerate(label_list):
            crop_img_name = f"{img_stem}_crop_{idx}.jpg"
            dest_img_path = os.path.join(dest_img_dir, crop_img_name)
            if do_crop:
                crop_img = _crop_bounding_box(img, label["points"])
                if crop_img.size == 0:
                    print(f"skip empty crop: {img_path} box {idx}")
                    continue
                if not cv2.imwrite(dest_img_path, crop_img):
                    raise OSError(f"failed to write {dest_img_path}")
            else:
                shutil.copy2(src_img_path, dest_img_path)
            rec_label_lines.append(f"{rel_crop_dir}/{crop_img_name}\t{label['transcription']}")

    rec_label_name = os.path.splitext(os.path.basename(label_file_path))[0] + "_rec.txt"
    rec_label_path = os.path.join(parent_dir, rec_label_name)
    with open(rec_label_path, "w", encoding="utf-8") as f:
        f.write("\n".join(rec_label_lines))
    return rec_label_path
# Turn every annotated box of both splits into recognition samples under "dataset".
generate_rec_img_label(label_file_path=train_det_label_file, parent_dir="dataset")
generate_rec_img_label(label_file_path=test_det_label_file, parent_dir="dataset")
posted @ 2024-10-31 16:07  shizidushu  阅读(160)  评论(0)  编辑  收藏  举报