PaddlePaddle验证码识别初尝试
前言
坑爹金智加了强制验证码登录,让我的小爬虫爬的不是这么快乐了。
人有计策,我有对策,让我们干它!
数据集准备
工欲善其事,必先利其器,这里需要准备验证码图片+正确标签喂给深度学习模型。
手工标注是不可能手工标注的,让我们偷点懒,做一下简单OCR。
CAPTCHA_URL = 'http://authserver.{你的学校}.edu.cn/authserver/captcha.html'
RAW_SAVE_PATH = 'datasets/'
def save(filepath):
captcha_url = 'http://authserver.{你的学校}.edu.cn/authserver/captcha.html'
res = requests.get(captcha_url)
with open(filepath, 'wb') as f:
f.write(res.content)
time.sleep(0.1)
def gen_filepath():
for i in range(10 * 10000):
filename = f"{i:08d}.jpg"
filepath = os.path.join(RAW_SAVE_PATH, filename)
if i % 10000 == 0:
print(filename)
if os.path.exists(filepath):
continue
yield filepath
if __name__ == '__main__':
with Pool() as p:
p.map(save, gen_filepath())
这样生成大量带标签的图片,用于接下来的训练。
训练
参考链接如下,我们跟着它修改修改:
https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/cv_case/image_ocr/image_ocr.html
网络
class Net(pp.nn.Layer):
def __init__(self, is_infer: bool = False):
super().__init__()
self.is_infer = is_infer
self.conv1 = pp.nn.Conv2D(in_channels=1,
out_channels=CHANNELS_BASE,
kernel_size=3)
self.bn1 = pp.nn.BatchNorm2D(CHANNELS_BASE)
self.conv2 = pp.nn.Conv2D(in_channels=CHANNELS_BASE,
out_channels=CHANNELS_BASE * 2,
kernel_size=3,
stride=2)
self.bn2 = pp.nn.BatchNorm2D(CHANNELS_BASE * 2)
self.conv3 = pp.nn.Conv2D(in_channels=CHANNELS_BASE * 2,
out_channels=CHANNELS_BASE,
kernel_size=1)
self.linear = pp.nn.Linear(in_features=660,
out_features=YZM_LENGTH + 4)
self.lstm = pp.nn.LSTM(input_size=CHANNELS_BASE,
hidden_size=CHANNELS_BASE // 2,
direction='bidirectional',
time_major=True)
self.linear2 = pp.nn.Linear(in_features=CHANNELS_BASE,
out_features=CLASSIFY_NUM)
def forward(self, ipt):
x = self.conv1(ipt)
x = pp.nn.functional.relu(x)
x = self.bn1(x)
x = self.conv2(x)
x = pp.nn.functional.relu(x)
x = self.bn2(x)
x = self.conv3(x)
x = pp.nn.functional.relu(x)
x = pp.tensor.flatten(x, 2)
x = self.linear(x)
x = pp.nn.functional.relu(x)
x = x.transpose([2, 0, 1])
x = self.lstm(x)[0]
x = self.linear2(x)
if self.is_infer:
x = x.transpose([1, 0, 2])
x = pp.nn.functional.softmax(x)
x = pp.argmax(x, axis=-1)
return x
损失函数
class CTCLoss(pp.nn.Layer):
def forward(self, ipt, label):
input_lengths = pp.full(shape=[BATCH_SIZE, 1], fill_value=YZM_LENGTH + 4, dtype='int64')
label_lengths = pp.full(shape=[BATCH_SIZE, 1], fill_value=YZM_LENGTH, dtype='int64')
loss = pp.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=len(CHAR_LIST))
return loss
资源下载地址
感谢百度提供的免费V100支持,这也是我选用PaddlePaddle的原因。
代码地址
https://aistudio.baidu.com/aistudio/projectdetail/2060359
数据集地址
https://aistudio.baidu.com/aistudio/datasetdetail/94535
部署
一开始给我坑了,把pdopt文件当成model文件,结果一直加载错误,其实要先转化一下:
inputs = pp.static.InputSpec(shape=[-1, 1, HEIGHT, WIDTH], dtype='float32', name='img')
net = Net(is_infer=True)
model_state_dict = pp.load(PARAMS_PATH)
net.set_state_dict(model_state_dict)
optimizer = pp.optimizer.Adam(learning_rate=0.0001, parameters=net.parameters())
opt_state_dict = pp.load(MODEL_PATH)
optimizer.set_state_dict(opt_state_dict)
net = to_static(net, input_spec=[inputs])
pp.jit.save(net, 'models/inference')
之后就可以愉快的使用了:
def predict_captcha(img):
img = parse_img(img)
img = pre_process(img)
img = np.expand_dims(img, axis=0)
input_names = _predictor.get_input_names()
input_handle = _predictor.get_input_handle(input_names[0])
input_handle.reshape([1, 1, HEIGHT, WIDTH])
input_handle.copy_from_cpu(img)
_predictor.run()
output_names = _predictor.get_output_names()
output_handle = _predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
return label_arr2text(ctc_decode(output_data[0]))
后记
该模型已部署到我自己的库(https://github.com/Licsber/licsber-pypi)中,对于验证码识别,只需要:
from licsber.auth import predict_captcha
对于学校的SSO登录:
from licsber.auth import get_wisedu_session
爬虫又可以快乐起来了(