念念不忘,必有回响

PaddlePaddle验证码识别初尝试

前言

坑爹金智加了强制验证码登录,让我的小爬虫爬的不是这么快乐了。
人有计策,我有对策,让我们干它!

数据集准备

工欲善其事,必先利其器,这里需要准备验证码图片+正确标签喂给深度学习模型。
手工标注是不可能手工标注的,让我们偷点懒,做一下简单OCR。

CAPTCHA_URL = 'http://authserver.{你的学校}.edu.cn/authserver/captcha.html'
RAW_SAVE_PATH = 'datasets/'

def save(filepath):
    captcha_url = 'http://authserver.{你的学校}.edu.cn/authserver/captcha.html'

    res = requests.get(captcha_url)
    with open(filepath, 'wb') as f:
        f.write(res.content)
    time.sleep(0.1)

def gen_filepath():
    for i in range(10 * 10000):
        filename = f"{i:08d}.jpg"
        filepath = os.path.join(RAW_SAVE_PATH, filename)
        if i % 10000 == 0:
            print(filename)

        if os.path.exists(filepath):
            continue

        yield filepath
        
if __name__ == '__main__':
    with Pool() as p:
        p.map(save, gen_filepath())

这样生成大量带标签的图片,用于接下来的训练。

训练

参考链接如下,我们跟着它修改修改:
https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/cv_case/image_ocr/image_ocr.html

网络

class Net(pp.nn.Layer):
    def __init__(self, is_infer: bool = False):
        super().__init__()
        self.is_infer = is_infer

        self.conv1 = pp.nn.Conv2D(in_channels=1,
                                  out_channels=CHANNELS_BASE,
                                  kernel_size=3)
        self.bn1 = pp.nn.BatchNorm2D(CHANNELS_BASE)
        self.conv2 = pp.nn.Conv2D(in_channels=CHANNELS_BASE,
                                  out_channels=CHANNELS_BASE * 2,
                                  kernel_size=3,
                                  stride=2)
        self.bn2 = pp.nn.BatchNorm2D(CHANNELS_BASE * 2)
        self.conv3 = pp.nn.Conv2D(in_channels=CHANNELS_BASE * 2,
                                  out_channels=CHANNELS_BASE,
                                  kernel_size=1)
        self.linear = pp.nn.Linear(in_features=660,
                                   out_features=YZM_LENGTH + 4)
        self.lstm = pp.nn.LSTM(input_size=CHANNELS_BASE,
                               hidden_size=CHANNELS_BASE // 2,
                               direction='bidirectional',
                               time_major=True)
        self.linear2 = pp.nn.Linear(in_features=CHANNELS_BASE,
                                    out_features=CLASSIFY_NUM)

    def forward(self, ipt):
        x = self.conv1(ipt)
        x = pp.nn.functional.relu(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = pp.nn.functional.relu(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = pp.nn.functional.relu(x)
        x = pp.tensor.flatten(x, 2)
        x = self.linear(x)
        x = pp.nn.functional.relu(x)
        x = x.transpose([2, 0, 1])
        x = self.lstm(x)[0]
        x = self.linear2(x)

        if self.is_infer:
            x = x.transpose([1, 0, 2])
            x = pp.nn.functional.softmax(x)
            x = pp.argmax(x, axis=-1)
        return x

损失函数

class CTCLoss(pp.nn.Layer):
    def forward(self, ipt, label):
        input_lengths = pp.full(shape=[BATCH_SIZE, 1], fill_value=YZM_LENGTH + 4, dtype='int64')
        label_lengths = pp.full(shape=[BATCH_SIZE, 1], fill_value=YZM_LENGTH, dtype='int64')
        loss = pp.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=len(CHAR_LIST))
        return loss

资源下载地址

感谢百度提供的免费V100支持,这也是我选用PaddlePaddle的原因。

代码地址

https://aistudio.baidu.com/aistudio/projectdetail/2060359

数据集地址

https://aistudio.baidu.com/aistudio/datasetdetail/94535

部署

一开始给我坑了,把pdopt文件当成model文件,结果一直加载错误,其实要先转化一下:

    inputs = pp.static.InputSpec(shape=[-1, 1, HEIGHT, WIDTH], dtype='float32', name='img')

    net = Net(is_infer=True)
    model_state_dict = pp.load(PARAMS_PATH)
    net.set_state_dict(model_state_dict)

    optimizer = pp.optimizer.Adam(learning_rate=0.0001, parameters=net.parameters())
    opt_state_dict = pp.load(MODEL_PATH)
    optimizer.set_state_dict(opt_state_dict)

    net = to_static(net, input_spec=[inputs])
    pp.jit.save(net, 'models/inference')

之后就可以愉快的使用了:

def predict_captcha(img):
    img = parse_img(img)
    img = pre_process(img)
    img = np.expand_dims(img, axis=0)

    input_names = _predictor.get_input_names()
    input_handle = _predictor.get_input_handle(input_names[0])
    input_handle.reshape([1, 1, HEIGHT, WIDTH])
    input_handle.copy_from_cpu(img)
    _predictor.run()

    output_names = _predictor.get_output_names()
    output_handle = _predictor.get_output_handle(output_names[0])
    output_data = output_handle.copy_to_cpu()
    return label_arr2text(ctc_decode(output_data[0]))

后记

该模型已部署到我自己的库(https://github.com/Licsber/licsber-pypi)中,对于验证码识别,只需要:

from licsber.auth import predict_captcha

对于学校的SSO登录:

from licsber.auth import get_wisedu_session

爬虫又可以快乐起来了(

posted on 2021-06-13 17:05  licsber  阅读(567)  评论(0编辑  收藏  举报