文字识别网络学习—CRNN+CTC

  OCR(Optical Character Recognition)任务主要是识别出图片中的文字,目前深度学习的方法采用两步来解决这个问题,一是文字检测网络定位文字位置,二是文字识别网络识别出文字。

  关于OCR的综述参考:http://xiaofengshi.com/2019/01/05/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-OCR_Overview/

  CRNN+CTC的文字识别网络是在2015年的论文An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition(2015) 中提出,主要用于序列文本的识别。CRNN的整体流程如下图所示,图片依次经过CNN卷积层,RNN循环层,最后经解码翻译处理得到最后的识别文本。

  对于CRNN文字识别网络的理解主要在于三方面:网络结构,CTC损失函数,数据预处理。CRNN参考代码地址:https://github.com/bgshih/crnn, https://github.com/meijieru/crnn.pytorch  

1. 网络结构

   CRNN的网络结构比较简单,包括VGG11和RNN两部分。采用VGG11进行特征提取,随后采用双层的BiLSTM提取序列信息,其网络结构如图所示:

  训练时模型的计算流程如下:

    1. 经过灰度化和resize后图片的尺寸为(B, 1, 32, 160),图片经过VGG11卷积层得到feature尺寸为(B, 512, 1, 41)

    2. feature经过RNN循环层后网络输出尺寸为(41, B, 4039)。(4039表示字典里共有4038个字符,还有一个字符"_"表示空格)

    3. 尺寸为(41, B, 4039)的输出经过log_softmax后,通过CTC计算loss 

2.CTC损失函数

   CTC(Connectionist Temporal Classification)是在2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks中提出,引入了空白符号,解决了损失计算时,文字标注和网络输出之间的对齐问题。其原理比较复杂,需要专研下,参考:CTC原理理解(转载)  

3.数据预处理

   由于CRNN对于序列行文本效果比较好,所以对于输入的图片除了进行resize和灰度化外,还有两点要注意。

  一是旋转角度过大的序列文本,需要进行一定的旋转,参考下面代码

from scipy.ndimage import filters, interpolation
from numpy import amin, amax


def estimate_skew_angle(raw):
    """
    估计图像文字角度
    因为文字是水平排版,那么此位置图像的行与行之间像素值的方差应该是最大。
    原理大概是这样,先对图像进行二值化处理,然后计算图像每行的均值向量,得到该向量的方差。如果图像文字不存在文字倾斜(假设所有文字朝向一致),那么对应的方差应该是最大,找到方差最大对应的角度,就是文字的倾斜角度。
    本项目中,只取了-15到15度,主要是计算速度的影响,如果不考虑性能,可以计算得更准备。
    """

    def resize_im(im, scale, max_scale=None):
        f = float(scale) / min(im.shape[0], im.shape[1])
        if max_scale != None and f * max(im.shape[0], im.shape[1]) > max_scale:
            f = float(max_scale) / max(im.shape[0], im.shape[1])
        return cv2.resize(im, (0, 0), fx=f, fy=f)

    raw = resize_im(raw, scale=600, max_scale=900)
    image = raw - amin(raw)
    image = image / amax(image)
    m = interpolation.zoom(image, 0.5)
    m = filters.percentile_filter(m, 80, size=(20, 2))
    m = filters.percentile_filter(m, 80, size=(2, 20))
    m = interpolation.zoom(m, 1.0 / 0.5)

    w, h = min(image.shape[1], m.shape[1]), min(image.shape[0], m.shape[0])
    flat = np.clip(image[:h, :w] - m[:h, :w] + 1, 0, 1)
    d0, d1 = flat.shape
    o0, o1 = int(0.1 * d0), int(0.1 * d1)
    flat = amax(flat) - flat
    flat -= amin(flat)
    est = flat[o0:d0 - o0, o1:d1 - o1]
    angles = range(-15, 15)
    estimates = []
    for a in angles:
        roest = interpolation.rotate(est, a, order=0, mode='constant')
        v = np.mean(roest, axis=1)
        v = np.var(v)
        estimates.append((v, a))

    _, a = max(estimates)
    return a

if __name__ == "__main__":
    import os

    src = r"F:\temp"
    for file in os.listdir(src):
        if file.endswith(".jpg"):
            img_path = os.path.join(src, file)
            # img = cv2.imread(img_path, 0)
            img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), flags=0)
            h, w = img.shape
            angle = estimate_skew_angle(img)
            print(file, angle)
            m = cv2.getRotationMatrix2D((int(w / 2), int(h / 2)), angle, 1)
            d = int(np.sqrt(h * h + w * w))
            img2 = cv2.warpAffine(img, m, (w, h))
            cv2.imshow("img2", img2)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
小角度估计

  二是对于长文本需要进行切割成小段,识别后再拼接,如上面网络输出序列为41*B*4039,表示支持的最长文本为41个字符

 

参考:https://zhuanlan.zhihu.com/p/43534801

posted @ 2021-01-06 14:59  silence_cho  阅读(2661)  评论(0编辑  收藏  举报