TensorFlow 识别单张图片 OCR（数字 0 到 9 的识别）

 

 

pip install numpy -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install tensorflow-gpu==1.15.0 -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install opencv-python -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install matplotlib -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

 

 

 

import time

import tensorflow as tf
import cv2 as cv
import numpy as np


def generate_image(a, rnd_size=100):
    """Render ``a`` onto a 28x28 black canvas and sprinkle it with black noise.

    Args:
        a: value to draw; it is converted with ``str`` before rendering.
        rnd_size: number of random pixel positions forced to 0. Positions are
            drawn independently (with replacement), so fewer distinct pixels
            may actually change.

    Returns:
        ``(image, data)`` where ``image`` is the 28x28 uint8 canvas and
        ``data`` is the same pixels reshaped to (1, 784) and divided by 255.
    """
    canvas = np.zeros((28, 28), dtype=np.uint8)
    # White (255) glyph, thickness 2, line type 8, text origin at (7, 21).
    cv.putText(canvas, str(a), (7, 21), cv.FONT_HERSHEY_PLAIN, 1.3, 255, 2, 8)

    # Noise: knock out random pixels one by one (row drawn before column).
    for _ in range(rnd_size):
        r = np.random.randint(0, 28)
        c = np.random.randint(0, 28)
        canvas[r, c] = 0

    flat = canvas.reshape(1, 784)
    return canvas, flat / 255


def display_images(images, cols=5):
    """Show ``images`` in a matplotlib grid and call ``plt.show()``.

    Args:
        images: sequence of 2-D arrays (e.g. the 28x28 canvases produced by
            ``generate_image``).
        cols: number of grid columns. The row count is derived from
            ``len(images)``, so any number of images fits. Ten images still
            give the original 2x5 layout. The previous fixed 2x5 grid made
            ``plt.subplot`` raise once more than 10 images were passed.
    """
    import math
    import matplotlib.pyplot as plt

    count = len(images)
    rows = max(1, math.ceil(count / cols))
    for i, img in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        # Single-channel data: render in gray rather than the default colormap.
        plt.imshow(img, cmap='gray')

    plt.show()


def load_data(sess, rnd_size=100, should_display_images=False):
    """Build one freshly noised batch containing each digit 0-9 exactly once.

    Args:
        sess: kept for backward compatibility; no longer used. Labels used to
            be produced by running a brand-new ``tf.one_hot`` op in this
            session. That added nodes to the default graph on every call
            (``do_train`` calls this once per training step), so the graph
            grew without bound. The one-hot matrix is now built with NumPy.
        rnd_size: number of noise pixels per image, forwarded to
            ``generate_image``.
        should_display_images: when truthy, show the ten images via
            ``display_images`` before returning.

    Returns:
        ``(x_features, y)``: ``x_features`` has shape (10, 784) with row ``i``
        holding digit ``i``; ``y`` is the matching (10, 10) float32 one-hot
        label matrix.
    """
    labels = list(range(10))
    # Generate in order 0..9 so random-noise consumption matches row order.
    images, features = zip(*(generate_image(d, rnd_size) for d in labels))

    if should_display_images:
        display_images(list(images))

    # Each feature is (1, 784); stack and flatten to (10, 784).
    x_features = np.reshape(np.array(features), (-1, 784))
    # Same float32 one-hot values tf.one_hot(labels, 10) produced.
    y = np.eye(10, dtype=np.float32)[labels]

    return x_features, y


def build_network(nhidden, classes_count, learning_rate=0.1):
    """Build a one-hidden-layer sigmoid network on flattened 784-value inputs.

    Args:
        nhidden: width of the hidden layer.
        classes_count: number of output units (width of the one-hot labels).
        learning_rate: gradient-descent step size. Defaults to 0.1, the value
            that was previously hard-coded.

    Returns:
        ``(x, y, out_result, loss, step)``:
        - ``x``: input placeholder, shape [None, 784].
        - ``y``: label placeholder, shape [None, classes_count].
        - ``out_result``: sigmoid output tensor.
        - ``loss``: summed squared-error loss.
        - ``step``: op that applies one gradient-descent update.

    Side effect: registers a ``"loss"`` scalar summary in the default graph,
    which ``tf.summary.merge_all()`` will pick up.
    """
    x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    y = tf.placeholder(tf.float32, shape=[None, classes_count], name='y')

    # Hidden layer: sigmoid(x @ W1 + b1); b1 is [1, nhidden] and broadcasts
    # over the batch dimension.
    W1 = tf.Variable(tf.random_normal([784, nhidden]))
    b1 = tf.Variable(tf.random_normal([1, nhidden]))
    hidden1_result = tf.sigmoid(tf.add(tf.matmul(x, W1), b1))

    # Output layer: an independent sigmoid per class (not a softmax).
    W2 = tf.Variable(tf.random_normal([nhidden, classes_count]))
    b2 = tf.Variable(tf.random_normal([1, classes_count]))
    out_result = tf.sigmoid(tf.add(tf.matmul(hidden1_result, W2), b2))

    # Sum (not mean) of squared errors over every sample and class.
    loss = tf.reduce_sum(tf.square(tf.subtract(out_result, y)))
    step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    tf.summary.scalar("loss", loss)

    return x, y, out_result, loss, step


def do_train():
    """Train the digit classifier on synthetic noisy digits and report results.

    Runs 800 gradient-descent steps, each on a freshly generated batch that
    holds one noisy image per digit. Every 50 steps it writes the loss summary
    to a ``logs-<timestamp>`` directory and prints the loss and how many of
    the ten images were classified correctly. Finally it generates a noisier
    batch (150 noise pixels), displays it, and prints the predicted digit for
    each image.
    """
    x, y, y_, loss, step = build_network(10, 10)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter('logs-' + str(time.time()), sess.graph)
        try:
            for i in range(800):
                x_features, y_labels = load_data(sess)
                feed = {x: x_features, y: y_labels}
                sess.run(step, feed_dict=feed)
                if (i + 1) % 50 == 0:
                    # Fetch loss, summary and predictions in a single run.
                    cur_loss, summary_, pred_ys = sess.run(
                        [loss, summary_merged, y_], feed_dict=feed)
                    writer.add_summary(summary_, i)

                    # Rows are samples and columns are classes, so each
                    # image's class is the argmax over axis 1 (axis 0 picked
                    # the best sample per class). NumPy is used instead of tf
                    # ops so no new nodes are added to the graph per check.
                    predicted = np.argmax(pred_ys, axis=1)
                    expected = np.argmax(y_labels, axis=1)
                    correct = int(np.sum(predicted == expected))
                    print(i + 1, ': loss: ', cur_loss, '正确个数:', correct)

            print('*************************')
            x_features, _ = load_data(sess, 150, should_display_images=True)
            pred_ys = sess.run(y_, feed_dict={x: x_features})
            r = np.argmax(pred_ys, axis=1)
            print('图片识别结果:', r)
        finally:
            # Flush and release the event file even if training fails.
            writer.close()


if __name__ == "__main__":
    # Train and evaluate only when executed as a script, not when imported.
    do_train()

  

 

 

 

50 : loss:  7.3588676 正确个数: 4.0
100 : loss:  6.6502814 正确个数: 5.0
150 : loss:  5.26784 正确个数: 7.0
200 : loss:  4.0591483 正确个数: 9.0
250 : loss:  3.4379258 正确个数: 8.0
300 : loss:  3.114149 正确个数: 8.0
350 : loss:  2.0274947 正确个数: 9.0
400 : loss:  1.4823446 正确个数: 10.0
450 : loss:  1.4051719 正确个数: 10.0
500 : loss:  0.91150457 正确个数: 10.0
550 : loss:  0.7835213 正确个数: 10.0
600 : loss:  0.72512466 正确个数: 10.0
650 : loss:  0.56525075 正确个数: 10.0
700 : loss:  0.4699742 正确个数: 10.0
750 : loss:  0.45453963 正确个数: 10.0
800 : loss:  0.45089394 正确个数: 10.0
*************************
图片识别结果: [0 1 2 3 4 5 6 7 8 9]

 

 

 

 

posted @ 2020-02-12 14:45  McKay  阅读(542)  评论(0)  编辑  收藏  举报