TensorFlow 识别单张图片 OCR（数字 0 到 9 的识别）

 

 

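pip install numpy -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install tensorflow-gpu==1.15.0 -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install opencv-python -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install matplotlib -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
（注：tensorflow-gpu 1.15.0 仅提供 Python 3.7 及以下版本的安装包；matplotlib 用于下方代码中的 display_images 显示图片。）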

 

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import time
 
import tensorflow as tf
import cv2 as cv
import numpy as np
 
 
def generate_image(a, rnd_size=100):
    """Render ``a`` onto a 28x28 canvas and sprinkle it with noise.

    The text ``str(a)`` is drawn with value 255 on a zeroed uint8 canvas.
    Then ``rnd_size`` randomly chosen pixels are forced back to 0. Positions
    are drawn independently, so they may repeat.

    Args:
        a: Value to draw; converted with ``str()``.
        rnd_size: Number of random pixel positions to zero out.

    Returns:
        A tuple ``(image, data)``. ``image`` is the 28x28 uint8 canvas.
        ``data`` holds the same pixels reshaped to (1, 784) and divided
        by 255.
    """
    canvas = np.zeros([28, 28], dtype=np.uint8)
    cv.putText(canvas, str(a), (7, 21), cv.FONT_HERSHEY_PLAIN, 1.3, 255, 2, 8)

    # Blank out random pixels so each generated sample differs slightly.
    # Row is drawn before column on every iteration (RNG call order matters).
    for _ in range(rnd_size):
        r = np.random.randint(0, 28)
        c = np.random.randint(0, 28)
        canvas[r, c] = 0

    flat = canvas.reshape([1, 784])
    return canvas, flat / 255
 
 
def display_images(images):
    """Plot the given images on a 2x5 matplotlib grid and show the figure.

    Args:
        images: Sequence of images accepted by ``plt.imshow``. The grid has
            exactly ten slots, so passing more than ten images makes
            ``plt.subplot`` raise.
    """
    # Imported lazily so matplotlib is only required when images are shown.
    import matplotlib.pyplot as plt

    for slot, img in enumerate(images, start=1):
        plt.subplot(2, 5, slot)
        plt.imshow(img)

    plt.show()
 
 
def load_data(sess, rnd_size=100, should_display_images=False):
    """Generate one noisy sample per digit 0-9 together with one-hot labels.

    Args:
        sess: Kept for backward compatibility and no longer used. The labels
            used to be produced by running ``tf.one_hot`` in this session, but
            that created a new op in the default graph on every call, so the
            graph grew each training step. They are now built with NumPy.
        rnd_size: Number of random pixels zeroed in each generated image.
        should_display_images: If truthy, show the ten images with
            ``display_images`` before returning.

    Returns:
        ``(x_features, y)``. ``x_features`` has shape (10, 784), with row
        ``i`` holding the flattened image of digit ``i``. ``y`` is a
        (10, 10) float32 one-hot matrix whose row ``i`` marks class ``i``.
    """
    labels = list(range(10))

    # Generate digits in order 0..9 so random noise is drawn in the same
    # sequence as the original one-call-per-digit code.
    samples = [generate_image(digit, rnd_size) for digit in labels]
    images = [image for image, _ in samples]
    features = [data for _, data in samples]

    if should_display_images:
        display_images(images)

    # Each feature row is (1, 784); stack and flatten to (10, 784).
    x_features = np.reshape(np.array(features), (-1, 784))

    # Same values and dtype that sess.run(tf.one_hot(labels, 10)) produced.
    y = np.eye(10, dtype=np.float32)[labels]

    return x_features, y
 
 
def build_network(nhidden, classes_count):
    """Build a one-hidden-layer sigmoid network trained on summed squared error.

    Args:
        nhidden: Number of units in the hidden layer.
        classes_count: Number of output units (one per class).

    Returns:
        ``(x, y, out_result, loss, step)``:
        - the input placeholder, shape ``[None, 784]``, named ``'x'``;
        - the target placeholder, shape ``[None, classes_count]``, named ``'y'``;
        - the sigmoid output tensor;
        - the summed squared-error loss;
        - the ``GradientDescentOptimizer(0.1)`` minimize op.

    Side effect: adds a ``"loss"`` scalar summary to the default graph.
    """
    x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    y = tf.placeholder(tf.float32, shape=[None, classes_count], name='y')

    def _sigmoid_layer(inputs, fan_in, fan_out):
        # Weights first, then bias, then matmul/add/sigmoid: the same op
        # creation order as the original hand-written layers.
        weights = tf.Variable(tf.random_normal([fan_in, fan_out]))
        bias = tf.Variable(tf.random_normal([1, fan_out]))
        return tf.sigmoid(tf.add(tf.matmul(inputs, weights), bias))

    hidden = _sigmoid_layer(x, 784, nhidden)
    prediction = _sigmoid_layer(hidden, nhidden, classes_count)

    loss = tf.reduce_sum(tf.square(tf.subtract(prediction, y)))
    step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    tf.summary.scalar("loss", loss)

    return x, y, prediction, loss, step
 
 
def do_train():
    """Train the digit network on freshly generated samples and report results.

    Runs 800 gradient-descent steps, each on a new batch of ten noisy digit
    images (one per digit 0-9) from ``load_data``. Every 50 steps it writes
    the loss summary and prints the loss plus how many of the ten samples
    are classified correctly. Afterwards it generates a noisier batch
    (``rnd_size=150``) with image display enabled, and prints the predicted
    digit for each sample.

    TensorBoard summaries go to a ``logs-<time.time()>`` directory.
    """
    x, y, y_, loss, step = build_network(10, 10)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter('logs-' + str(time.time()), sess.graph)
        try:
            for i in range(800):
                x_features, y_labels = load_data(sess)
                feed = {x: x_features, y: y_labels}
                sess.run(step, feed_dict=feed)
                if (i + 1) % 50 == 0:
                    # A single run fetches loss, summary and predictions for
                    # the same batch (previously two separate runs).
                    cur_loss, summary_, pred_ys = sess.run(
                        [loss, summary_merged, y_], feed_dict=feed)
                    writer.add_summary(summary_, i)

                    # Score in NumPy instead of creating new tf ops here,
                    # which grew the graph on every evaluation. axis=1 gives
                    # each sample's (row's) predicted class; the old axis=0
                    # picked, per class, the sample with the highest score.
                    correct = np.argmax(pred_ys, axis=1) == np.argmax(y_labels, axis=1)
                    r = int(np.sum(correct))
                    print(i + 1, ': loss: ', cur_loss, '正确个数:', r)

            print('*************************')
            x_features, y_labels = load_data(sess, 150, should_display_images=True)
            pred_ys = sess.run(y_, feed_dict={x: x_features})
            # Predicted digit for each sample, in sample order.
            r = np.argmax(pred_ys, axis=1)
            print('图片识别结果:', r)
        finally:
            # Flush and close the event file even if training is interrupted.
            writer.close()
 
 
if __name__ == '__main__':
    # Script entry point: train and evaluate only when run directly, not on import.
    do_train()

  

 

 

 

复制代码
50 : loss:  7.3588676 正确个数: 4.0
100 : loss:  6.6502814 正确个数: 5.0
150 : loss:  5.26784 正确个数: 7.0
200 : loss:  4.0591483 正确个数: 9.0
250 : loss:  3.4379258 正确个数: 8.0
300 : loss:  3.114149 正确个数: 8.0
350 : loss:  2.0274947 正确个数: 9.0
400 : loss:  1.4823446 正确个数: 10.0
450 : loss:  1.4051719 正确个数: 10.0
500 : loss:  0.91150457 正确个数: 10.0
550 : loss:  0.7835213 正确个数: 10.0
600 : loss:  0.72512466 正确个数: 10.0
650 : loss:  0.56525075 正确个数: 10.0
700 : loss:  0.4699742 正确个数: 10.0
750 : loss:  0.45453963 正确个数: 10.0
800 : loss:  0.45089394 正确个数: 10.0
*************************
图片识别结果: [0 1 2 3 4 5 6 7 8 9]
复制代码

 

 

 

 

posted @   McKay  阅读(546)  评论(0编辑  收藏  举报
编辑推荐:
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
历史上的今天:
2014-02-12 消息队列工具类(MSMQ)
点击右上角即可分享
微信分享提示