cnn进行端到端的验证码识别改进

keras_cnn.py 训练及建模

#!/usr/bin/env python
# coding=utf-8

"""
利用keras cnn进行端到端的验证码识别, 简单直接暴力。
迭代100次可以达到95%的准确率,但是很容易过拟合,泛化能力糟糕, 除了增加训练数据还没想到更好的方法.

__author__: jkmiao
__email__: miao1202@126.com
__date__: 2017-02-08

"""
from keras.models import Model
from keras.layers import Dense, Dropout, Flatten, Input, merge
from keras.layers import Convolution2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import os, random
import numpy as np
from keras.models import model_from_json
from util import CharacterTable
from keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
# from keras.utils.visualize_util import plot

def load_data(path='img/clearNoise/'):
    """Load labelled captcha images from *path* as binarised arrays.

    Every directory entry whose name ends in ``jpg`` is considered, in
    random order. The label is the part of the file name before the first
    underscore; files whose label is not exactly 6 characters long are
    reported on stdout and skipped.

    Each accepted image is converted to greyscale, thresholded at 180 and
    reshaped to ``(60, 200, 1)``, so it must contain exactly 60 * 200 pixels
    (the reshape fails otherwise).

    Returns a tuple ``(data, label)``: ``data`` is a numpy array stacking
    the 0/1 image arrays, ``label`` is the list of lower-cased label strings
    in the same order.
    """
    fnames = [os.path.join(path, fname) for fname in os.listdir(path) if fname.endswith('jpg')]
    random.shuffle(fnames)
    data, label = [], []
    for fname in fnames:
        # basename() rather than split('/'): the label is still found when
        # the platform's path separator is not '/'.
        imgLabel = os.path.basename(fname).split('_')[0]
        if len(imgLabel) != 6:
            # Single-argument print(): same text as the old print statement
            # on Python 2, and valid syntax on Python 3.
            print('error:  ' + fname)
            continue
        imgM = np.array(Image.open(fname).convert('L'))
        # Binarise: pixels brighter than 180 become 1, all others 0.
        imgM = 1 * (imgM > 180)
        data.append(imgM.reshape((60, 200, 1)))
        label.append(imgLabel.lower())
    return np.array(data), label

ctable = CharacterTable()
data, label = load_data()
# Sanity check of the value range of the first sample.
# Single-argument print() keeps the Python 2 output and parses on Python 3.
print('%s %s' % (data[0].max(), data[0].min()))

# One row per sample. 216 columns = 6 characters per label (see load_data)
# times 36 classes per character (the six Dense(36) heads of the model).
label_onehot = np.zeros((len(label), 216))
for i, lb in enumerate(label):
    label_onehot[i, :] = ctable.encode(lb)
print('%s %s %s' % (data.shape, data[-1].max(), data[-1].min()))
print(label_onehot.shape)


# Mild geometric augmentation only; captchas must not be mirrored.
datagen = ImageDataGenerator(shear_range=0.08, zoom_range=0.08, horizontal_flip=False,
                             rotation_range=5, width_shift_range=0.06, height_shift_range=0.06)

# NOTE(review): none of the featurewise / ZCA options are enabled above, so
# this fit() is presumably a no-op -- confirm against the Keras docs before
# removing it.
datagen.fit(data)

# Hold out 10% of the samples for validation and the final accuracy check.
x_train, x_test, y_train, y_test = train_test_split(data, label_onehot, test_size=0.1)

# True: build the network from scratch.
# False: resume from the previously saved model3 architecture and weights.
DEBUG = False


def _conv_branch(image, first_kernel):
    """Return one flattened convolutional branch applied to *image*.

    The first convolution uses a ``first_kernel`` x ``first_kernel`` filter
    with ReLU; it is followed by pooling, two 3x3 convolutions, pooling and
    a final 3x3 convolution, all with 16 filters.
    """
    inner = Convolution2D(16, first_kernel, first_kernel, border_mode='same', activation='relu')(image)
    inner = MaxPooling2D(pool_size=(2, 2))(inner)
    # NOTE(review): the remaining convolutions have no activation argument,
    # exactly as in the original design -- confirm that is intended.
    inner = Convolution2D(16, 3, 3, border_mode='same')(inner)
    inner = Convolution2D(16, 3, 3, border_mode='same')(inner)
    inner = MaxPooling2D(pool_size=(2, 2))(inner)
    inner = Convolution2D(16, 3, 3, border_mode='same')(inner)
    return Flatten()(inner)


# Build the model
if DEBUG:
    input_img = Input(shape=(60, 200, 1))

    # Three parallel branches that differ only in the first kernel size.
    encoder_a = _conv_branch(input_img, 7)
    encoder_b = _conv_branch(input_img, 5)
    encoder_c = _conv_branch(input_img, 3)

    # Named `features` rather than `input` so the builtin is not shadowed.
    features = merge([encoder_a, encoder_b, encoder_c], mode='concat', concat_axis=-1)
    drop = Dropout(0.5)(features)
    flatten = Dense(216)(drop)
    flatten = Dropout(0.5)(flatten)

    # One 36-way softmax head per captcha character, concatenated into a
    # single 216-wide output that matches the one-hot label layout.
    heads = [Dense(36, activation='softmax')(flatten) for _ in range(6)]
    merged = merge(heads, mode='concat', concat_axis=-1)

    model = Model(input=input_img, output=merged)
else:
    # Context manager so the JSON file handle is closed deterministically.
    with open('model/ba_cnn_model3.json') as fr:
        model = model_from_json(fr.read())
    model.load_weights('model/ba_cnn_model3.h5')

# Compile
# model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# plot(model, to_file='model3.png', show_shapes=True)

# Train
# Stop once the validation loss has not improved for 5 consecutive epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=5)

model.fit_generator(
    datagen.flow(x_train, y_train, batch_size=32),
    samples_per_epoch=len(x_train),
    nb_epoch=50,
    validation_data=(x_test, y_test),
    callbacks=[early_stopping]
)

# Persist the architecture (JSON) and the weights (HDF5) separately.
json_string = model.to_json()
with open('./model/ba_cnn_model4.json', 'w') as fw:
    fw.write(json_string)
model.save_weights('./model/ba_cnn_model4.h5')

# The message used to say "cnn3" although the files written are model4.
print('done saved model cnn4')

# Evaluate on the held-out split: a sample counts as correct only when the
# whole decoded string matches.
y_pred = model.predict(x_test, verbose=1)
cnt = 0
for i, (pred_vec, true_vec) in enumerate(zip(y_pred, y_test)):
    guess = ctable.decode(pred_vec)
    correct = ctable.decode(true_vec)
    if guess == correct:
        cnt += 1
    # Show every tenth prediction next to its ground truth.
    if i % 10 == 0:
        # Single-argument print() keeps the Python 2 output and parses on
        # Python 3.
        print('%s %s' % ('--' * 10, i))
        print('y_pred %s' % (guess,))
        print('y_test %s' % (correct,))
# Fraction of held-out captchas that were decoded entirely correctly.
print(cnt / float(len(y_pred)))

 

apicode.py  模型使用

#!/usr/bin/env python
# coding=utf-8

from util import CharacterTable
from keras.models import model_from_json
from PIL import Image
import matplotlib.pyplot as plt
import os
import numpy as np
from prepare import clearNoise

def img2vec(fname):
    """Turn the image file *fname* into a one-sample batch for the model.

    The image is passed through ``clearNoise``, converted to greyscale and
    thresholded at 180 into a float array of 0.0 / 1.0 values. The maximum
    and minimum of that array are printed as a quick sanity check.

    Returns a tuple ``(batch, imgM)``: ``batch`` is a numpy array of shape
    ``(1, 60, 200, 1)`` and ``imgM`` is the binarised image before
    reshaping.
    """
    img = clearNoise(fname).convert('L')
    imgM = 1.0 * (np.array(img) > 180)
    # Single-argument print(): same text as the old print statement on
    # Python 2, and valid syntax on Python 3.
    print('%s %s' % (imgM.max(), imgM.min()))
    # Wrap the single sample in a leading batch axis.
    return np.array([imgM.reshape((60, 200, 1))]), imgM

ctable = CharacterTable()

# Restore the trained network: architecture from JSON, weights from HDF5.
# The context manager closes the JSON file handle deterministically.
with open('model/ba_cnn_model4.json') as fr:
    model = model_from_json(fr.read())
model.load_weights('model/ba_cnn_model4.h5')

def test(path, limit=50):
    """Run the model on the first *limit* files in *path* and print accuracy.

    The expected label of each file is the part of its name before the
    first underscore. For every file the outcome and the running accuracy
    are printed; nothing is returned.

    path  -- directory containing the captcha images to check
    limit -- maximum number of directory entries to evaluate (default 50)
    """
    fnames = [os.path.join(path, fname) for fname in os.listdir(path)][:limit]
    correct = 0
    for idx, fname in enumerate(fnames, 1):
        data, imgM = img2vec(fname)
        y_pred = model.predict(data)
        result = ctable.decode(y_pred[0])
        # basename() rather than split('/') so the label is found whatever
        # the path separator is. Lower-cased because the training script
        # lower-cases labels before encoding them; assumes ctable.decode
        # yields lower-case text -- TODO confirm.
        label = os.path.basename(fname).split('_')[0].lower()
        # Single-argument print() calls: same text as the old print
        # statements on Python 2, and valid syntax on Python 3.
        if result == label:
            correct += 1
            print('correct %s' % (fname,))
        else:
            print('%s %s' % (result, label))
        print('accuracy:  %s %s' % (idx, float(correct) / idx))
        print('==' * 20)
        # Optional visual debugging (original vs. binarised image), disabled:
#        plt.subplot(121)
#        plt.imshow(Image.open(fname).convert('L'), plt.cm.gray)
#        plt.title(fname)
#
#        plt.subplot(122)
#        plt.imshow(imgM, plt.cm.gray)
#        plt.title(result)
#        plt.show()

# Evaluate the loaded model on the images in the local 'test' directory.
test('test')

 

posted on 2017-03-10 16:18  星空守望者--jkmiao  阅读(2686)  评论(1)  编辑  收藏  举报