tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛

tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用。

数据目录在data,data下放了汉字识别图片:

data$ ls
0  1  10  11  12  13  14  15  16  2  3  4  5  6  7  8  9
data$ ls 0
xxx.png yyy.png ....

代码:

 如果将get_model里的模型层数加得非常深,训练时很可能不会收敛,精度一直停留在1%以内。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# -*- coding: utf-8 -*-
 
 
from __future__ import division, print_function, absolute_import
 
import os
import numpy as np
import pickle
import tflearn
 
from PIL import Image
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d, avg_pool_2d
from tflearn.layers.merge_ops import merge
from tflearn.layers.estimator import regression
from tflearn.data_utils import to_categorical, shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tflearn.layers.conv import highway_conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization, batch_normalization
 
def resize_image(in_image, new_width, new_height, out_image=None,
                 resize_mode=Image.LANCZOS):
    """ Resize an image.

    `Image.LANCZOS` is the same filter that used to be exposed as
    `Image.ANTIALIAS`; the old alias no longer exists in recent Pillow
    releases, which made this `def` fail at import time.

    Arguments:
        in_image: `PIL.Image`. The image to resize.
        new_width: `int`. The image new width.
        new_height: `int`. The image new height.
        out_image: `str`. If specified, save the image to the given path.
        resize_mode: PIL resampling filter passed to `Image.resize`.

    Returns:
        `PIL.Image`. The resized image.
    """
    img = in_image.resize((new_width, new_height), resize_mode)
    if out_image:
        img.save(out_image)
    return img
 
 
def convert_color(in_image, mode):
    """ Return the result of `in_image.convert(mode)`. """
    converted = in_image.convert(mode)
    return converted
 
 
def pil_to_nparray(pil_image):
    """ Load `pil_image` and return its content as a float32 numpy array. """
    # Call load() first, exactly as before, then convert.
    pil_image.load()
    pixels = np.asarray(pil_image, dtype=np.float32)
    return pixels
 
 
def iterbrowse(path):
    """ Lazily yield the joined path of every file that os.walk finds under `path`. """
    for root, _subdirs, names in os.walk(path):
        for full_path in (os.path.join(root, name) for name in names):
            yield full_path
 
 
def directory_to_samples(directory, flags):
    """ Read a directory, and list all subdirectories files as class sample.

    Every immediate sub-directory of `directory` is one class. Classes are
    numbered from 0 in the *string-sorted* order of their directory names
    (so a directory named "10" sorts before "2").

    Arguments:
        directory: `str`. Root directory holding one sub-directory per class.
        flags: list of `str` or None. A file is kept when its name contains
            at least one of these substrings (e.g. '.png'). An empty list or
            None keeps every file.

    Returns:
        `(samples, targets)`: two parallel lists, the file paths and the
        integer class label of each file.
    """
    samples = []
    targets = []
    # next() on the os.walk generator yields only the top-level entry and
    # works on both Python 2 and Python 3.
    classes = sorted(next(os.walk(directory))[1])
    # label class is from 0 !!!
    for label, c in enumerate(classes):
        c_dir = os.path.join(directory, c)
        filenames = next(os.walk(c_dir))[2]
        for sample in filenames:
            if not flags or any(flag in sample for flag in flags):
                samples.append(os.path.join(c_dir, sample))
                targets.append(label)
    return samples, targets
 
 
# Get the pixel from the given image
def get_pixel(image, i, j):
    """ Return the pixel of `image` at column `i`, row `j`.

    Arguments:
        image: object exposing `size` as `(width, height)` and
            `getpixel((x, y))` (e.g. a `PIL.Image`).
        i: `int`. X coordinate.
        j: `int`. Y coordinate.

    Returns:
        The value returned by `image.getpixel((i, j))`, or None when `i` or
        `j` lies past the right/bottom edge of the image.
    """
    # Inside image bounds? Valid coordinates are 0..width-1 and 0..height-1,
    # so `i == width` / `j == height` are already outside.
    width, height = image.size
    if i >= width or j >= height:
        return None

    # Get Pixel
    pixel = image.getpixel((i, j))
    return pixel
 
 
# Create a Grayscale version of the image
def convert_to_one_channel(image):
    """ Reduce an image array to a single channel, keeping a channel axis.

    For a 3-D array only channel 0 is kept. This relies on the picture
    being grayscale stored as multi-channel data (R == G == B); that
    assumption is NOT verified here.

    A 2-D array (an image that has no channel axis at all) gets a trailing
    axis of length 1 instead of raising IndexError.

    Arguments:
        image: `numpy.ndarray` indexed as [row, column] or
            [row, column, channel].

    Returns:
        `numpy.ndarray` with a last axis of length 1.
    """
    if np.ndim(image) == 2:
        # Already single-channel: only add the missing channel axis.
        return image[:, :, np.newaxis]
    # Just extract 1 channel data
    return image[:, :, [0]]
 
 
 
def image_dirs_to_samples(directory, resize=None, convert_gray=False,
                          filetypes=None):
    """ Load every matching image below `directory` as a one-channel array.

    Arguments:
        directory: `str`. Root directory, one sub-directory per class.
        resize: `(width, height)` or None. When given, each image is resized
            to that size before conversion.
        convert_gray: currently unused by this function.
        filetypes: forwarded to `directory_to_samples` as `flags`.

    Returns:
        `(samples, targets)`: the list of image arrays (pixel values divided
        by 255) and the list of integer class labels.
    """
    print("Starting to parse images...")
    samples, targets = directory_to_samples(directory, flags=filetypes)
    for index, path in enumerate(samples):
        print("Process %d th file %s" % (index + 1, path))
        picture = Image.open(path)  # returns a PIL.Image
        if resize:
            target_width, target_height = resize[0], resize[1]
            picture = resize_image(picture, target_width, target_height)
        pixels = pil_to_nparray(picture)
        pixels /= 255.  # scale the pixel values by 1/255
        # Keep a single channel and store the array in place of the path.
        samples[index] = convert_to_one_channel(pixels)
    print("Parsing Done!")
    return samples, targets
 
 
def load_data(dirname, resize_pics=(128, 128), shuffle_data=True):
    """ Load the image dataset under `dirname`, using a pickle cache.

    If `<dirname>/data.pkl` can be read, its content is returned as-is; in
    that case `resize_pics` and `shuffle_data` are ignored. Otherwise the
    images are parsed from disk, the labels are one-hot encoded, and the
    result is written to the cache file.

    Arguments:
        dirname: `str`. Dataset root directory (one sub-directory per class).
        resize_pics: `(width, height)`. Size every image is resized to.
        shuffle_data: `bool`. Shuffle samples and labels together.

    Returns:
        `(X, Y, org_labels)`: image arrays, one-hot labels, integer labels.
    """
    dataset_file = os.path.join(dirname, 'data.pkl')
    try:
        # NOTE: unpickling runs arbitrary code; only use a cache file that
        # this script wrote itself.
        with open(dataset_file, 'rb') as cache:
            X, Y, org_labels = pickle.load(cache)
    except Exception:
        # Any failure to read the cache (missing, truncated, wrong layout)
        # deliberately falls back to rebuilding it from the image files.
        X, Y = image_dirs_to_samples(dirname, resize_pics, False,
                                     ['.jpg', '.png'])  # TODO, memory is too small to load all data
        org_labels = Y
        # First class is '0', so the number of classes is max label + 1.
        Y = to_categorical(Y, np.max(Y) + 1)
        if shuffle_data:
            X, Y, org_labels = shuffle(X, Y, org_labels)
        with open(dataset_file, 'wb') as cache:
            pickle.dump((X, Y, org_labels), cache)
    return X, Y, org_labels
 
 
class EarlyStoppingCallback(tflearn.callbacks.Callback):
    """ Stop training once both training and validation accuracy reach a threshold. """

    def __init__(self, val_acc_thresh):
        # Store a validation accuracy threshold, which we can compare against
        # the current validation accuracy at, say, each epoch, each batch step, etc.
        self.val_acc_thresh = val_acc_thresh

    def on_epoch_end(self, training_state):
        """
        Raise StopIteration when `training_state.val_acc` and
        `training_state.acc_value` are both >= the threshold.

        The caller of `model.fit` is expected to catch StopIteration.
        """
        val_acc = training_state.val_acc
        acc = training_state.acc_value
        # A metric that has not been computed cannot be compared to the
        # threshold (None >= float is a TypeError on Python 3).
        if val_acc is None or acc is None:
            return
        if val_acc >= self.val_acc_thresh and acc >= self.val_acc_thresh:
            # Announce termination only when training really stops; it used
            # to be printed after every epoch.
            print("Terminating training at the end of epoch", training_state.epoch)
            raise StopIteration

    def on_train_end(self, training_state):
        """
        Print the final training accuracy (`training_state.acc_value`).
        """
        print("Successfully left training! Final model accuracy:", training_state.acc_value)
 
 
def get_model(width, height, classes=40):
    """ Build the CNN classifier and wrap it in a `tflearn.DNN`.

    Layers: conv(32) -> max-pool -> conv(64) -> conv(64) -> max-pool ->
    fully-connected(512) -> dropout(0.5) -> softmax(`classes`), trained with
    Adam (learning rate 0.001) and categorical cross-entropy.

    Arguments:
        width: `int`. Input image width in pixels.
        height: `int`. Input image height in pixels.
        classes: `int`. Number of output classes.

    Returns:
        `tflearn.DNN`. The untrained model.
    """
    # TODO, modify model
    # Real-time data preprocessing (the object used to be created twice).
    img_prep = tflearn.ImagePreprocessing()
    img_prep.add_featurewise_zero_center(per_channel=True)
    img_prep.add_featurewise_stdnorm()
    # A numpy array built from a PIL image is indexed [row, column], i.e.
    # (height, width). For square inputs this equals the previous
    # (width, height) order.
    network = input_data(shape=[None, height, width, 1], data_preprocessing=img_prep)  # if RGB, 224,224,3
    network = conv_2d(network, 32, 3, activation='relu')
    network = max_pool_2d(network, 2)
    network = conv_2d(network, 64, 3, activation='relu')
    network = conv_2d(network, 64, 3, activation='relu')
    network = max_pool_2d(network, 2)
    network = fully_connected(network, 512, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, classes, activation='softmax')
    network = regression(network, optimizer='adam',
                         loss='categorical_crossentropy',
                         learning_rate=0.001)
    model = tflearn.DNN(network, tensorboard_verbose=0)
    return model
 
 
# Script entry point: load the dataset, train the CNN until the early-stopping
# threshold is reached (or 500 epochs), save the model, then print a
# classification report and a confusion matrix.
if __name__ == "__main__":
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    # 80/20 train/validation split with a fixed seed.
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])
 
    # NOTE(review): `classes` must match the width of the one-hot labels that
    # load_data produced (max label + 1) - confirm for your dataset.
    model = get_model(width, height, classes=100)
 
    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    #try:
    #    model.load(filename)
    #    print("Model loaded OK. Resume training!")
    #except:
    #    pass
 
    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.9)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration as e:
        # Raised by EarlyStoppingCallback.on_epoch_end to end training early.
        print("OK, stop iterate!Good!")
 
    model.save(filename)
 
    # predict all data and calculate confusion_matrix
    # The model is reloaded from the file that was just written.
    model.load(filename)
 
    # X is the whole dataset, so these metrics cover the training samples as
    # well as the validation samples.
    pro_arr =model.predict(X)
    predict_labels = np.argmax(pro_arr, axis=1)
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))

 

1
 

 运行效果:100个汉字2分钟就可以达到95%精度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
---------------------------------
Run id: cnn_handwrite
Log directory: /tmp/tflearn_logs/
---------------------------------
Preprocessing... Calculating mean over all dataset (this may take long)...
Mean: [ 0.89235026] (To avoid repetitive computation, add it to argument 'mean' of `add_featurewise_zero_center`)
---------------------------------
Preprocessing... Calculating std over all dataset (this may take long)...
STD: 0.192279 (To avoid repetitive computation, add it to argument 'std' of `add_featurewise_stdnorm`)
---------------------------------
Training samples: 19094
Validation samples: 4774
--
Training Step: 597  | total loss: 0.70288 | time: 40.959ss
| Adam | epoch: 001 | loss: 0.70288 - acc: 0.7922 | val_loss: 0.54380 - val_acc: 0.8460 -- iter: 19094/19094
--
Terminating training at the end of epoch 1
 Training Step: 1194  | total loss: 0.48860 | time: 40.245s
| Adam | epoch: 002 | loss: 0.48860 - acc: 0.8783 | val_loss: 0.37020 - val_acc: 0.8923 -- iter: 19094/19094
--
Terminating training at the end of epoch 2
Training Step: 1791  | total loss: 0.35790 | time: 41.315ss
| Adam | epoch: 003 | loss: 0.35790 - acc: 0.9090 | val_loss: 0.34719 - val_acc: 0.9049 -- iter: 19094/19094
--
Terminating training at the end of epoch 3
Successfully left training! Final model accuracy: 0.908959209919
OK, stop iterate!Good!
             precision    recall  f1-score   support
 
          0       1.00      0.99      0.99       239
          1       0.95      0.96      0.96       237
          2       0.91      0.98      0.94       240
          3       0.90      0.98      0.94       239
          4       0.96      0.98      0.97       239
          5       0.94      0.97      0.96       239
          6       0.98      0.98      0.98       239
          7       0.84      0.99      0.91       240
          8       0.99      0.87      0.93       239
          9       0.95      0.98      0.96       239
         10       0.97      0.94      0.96       240
         11       0.95      0.98      0.97       240
         12       0.92      0.99      0.95       240
         13       0.95      0.96      0.96       239
         14       0.94      0.94      0.94       236
         15       0.94      0.97      0.96       240
         16       0.94      0.98      0.96       240
         17       0.97      0.99      0.98       240
         18       0.94      0.93      0.94       240
         19       1.00      0.95      0.98       239
         20       0.96      0.98      0.97       240
         21       0.98      0.91      0.95       239
         22       0.97      0.95      0.96       239
         23       1.00      0.97      0.98       239
         24       0.94      0.98      0.96       240
         25       0.98      0.98      0.98       237
         26       0.91      1.00      0.95       239
         27       0.91      0.96      0.93       239
         28       0.97      0.88      0.92       239
         29       1.00      0.98      0.99       240
         30       0.99      0.94      0.96       239
         31       0.97      0.97      0.97       237
         32       0.94      0.98      0.96       236
         33       0.94      0.96      0.95       239
         34       0.98      0.99      0.98       239
         35       0.99      0.98      0.99       240
         36       0.96      0.92      0.94       239
         37       1.00      0.93      0.96       240
         38       0.96      0.99      0.98       238
         39       0.98      0.97      0.97       238
         40       0.92      0.90      0.91       240
         41       0.96      0.97      0.96       237
         42       0.98      0.97      0.97       240
         43       0.95      0.96      0.95       239
         44       0.97      0.96      0.97       239
         45       0.95      0.94      0.95       239
         46       0.93      0.96      0.94       232
         47       0.98      0.91      0.94       237
         48       0.95      0.97      0.96       239
         49       0.97      0.80      0.88       226
         50       0.90      0.95      0.92       240
         51       0.98      0.99      0.99       236
         52       0.96      0.90      0.93       240
         53       0.99      0.96      0.97       235
         54       0.97      0.93      0.95       240
         55       0.99      0.98      0.99       240
         56       0.97      0.92      0.95       239
         57       0.97      0.97      0.97       239
         58       1.00      0.98      0.99       238
         59       0.92      0.98      0.95       240
         60       0.99      0.90      0.94       240
         61       1.00      0.99      0.99       238
         62       0.92      0.95      0.94       239
         63       0.92      0.98      0.95       238
         64       0.98      0.92      0.95       240
         65       0.99      0.92      0.95       239
         66       0.98      0.99      0.99       240
         67       0.95      0.95      0.95       240
         68       0.96      0.98      0.97       239
         69       0.97      0.97      0.97       239
         70       0.98      0.94      0.96       240
         71       0.91      0.96      0.93       239
         72       0.98      0.97      0.97       239
         73       0.99      0.89      0.94       240
         74       0.97      0.99      0.98       237
         75       0.89      0.97      0.92       240
         76       0.97      0.96      0.97       241
         77       0.89      0.91      0.90       240
         78       1.00      0.89      0.94       239
         79       0.90      0.98      0.94       239
         80       0.89      0.96      0.92       240
         81       0.96      0.71      0.81       225
         82       0.95      1.00      0.97       238
         83       0.67      0.96      0.79       239
         84       0.97      0.85      0.91       240
         85       0.95      0.98      0.96       239
         86       0.99      0.93      0.96       240
         87       0.98      0.91      0.94       239
         88       0.97      0.97      0.97       240
         89       0.97      0.94      0.95       239
         90       0.97      0.96      0.96       236
         91       0.91      0.97      0.94       239
         92       0.98      0.95      0.96       240
         93       0.98      0.97      0.98       239
         94       0.98      0.95      0.97       240
         95       0.98      0.99      0.99       239
         96       0.95      0.97      0.96       240
         97       0.98      0.97      0.98       239
         98       0.95      0.98      0.97       237
         99       0.97      0.96      0.97       239
 
avg / total       0.96      0.95      0.95     23868
 
[[237   0   0 ...,   0   0   0]
 [  0 228   0 ...,   0   0   0]
 [  0   0 235 ...,   0   0   0]
 ...,
 [  0   0   0 ..., 233   0   0]
 [  0   0   0 ...,   0 233   0]
 [  0   0   0 ...,   0   0 230]]

 更多模型见:http://www.cnblogs.com/bonelee/p/8978060.html

 

要将上述模型保存并给TensorFlow使用,只需在保存模型前加上 del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:],这样只保留inference时需要的OP(注意:如果之后还需要retrain,请另外保存一份未删除训练OP的模型),如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Train the model, save a full checkpoint, then drop the training ops from the
# graph and save a second, inference-only checkpoint.
# NOTE(review): this fragment uses `tf`, `width`, `height`, `trainX`, `trainY`,
# `testX`, `testY`, `get_model` and `EarlyStoppingCallback`, none of which it
# defines - presumably it needs `import tensorflow as tf` plus the training
# script's definitions; confirm before running.
model = get_model(width, height, classes=100)
 
filename = 'cnn_handwrite-acc0.8.tflearn'
# try to load model and resume training
#try:
#    model.load(filename)
#    print("Model loaded OK. Resume training!")
#except:
#    pass
 
# Initialize our callback with desired accuracy threshold.
early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.8)
try:
    model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
              snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
              show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
except StopIteration as e:
    # Raised by the early-stopping callback to end training.
    print("OK, stop iterate!Good!")
 
# First save: made while the TRAIN_OPS collection is still populated.
model.save(filename)
# Empty the TRAIN_OPS graph collection in place so that the next save is made
# without the training ops.
del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
 
"""
# print op name
with tf.Session() as sess:
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        for v in sess.graph.get_operations():
            print(v.name)
"""
 
# Second save, under a different file name, made after TRAIN_OPS was emptied.
filename = 'cnn_handwrite-acc0.8.infer.tflearn'
model.save(filename)

 参考:http://www.cnblogs.com/bonelee/p/8445261.html 里的脚本,修改:

output_node_names = "FullyConnected/Softmax"
通常为:
output_node_names = "FullyConnected/Softmax"
或者
output_node_names = "FullyConnected_1/Softmax"
output_node_names = "FullyConnected_2/Softmax"
具体用哪个取决于你使用的全连接层数,上面三个名字分别对应1、2、3层全连接层。本文的模型有两层全连接层,因此应使用 FullyConnected_1/Softmax。
最后,tensorflow里的使用:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def inference(image):
    """ Classify one image file with the frozen graph in `frozen_model.pb`.

    The image is converted to grayscale, resized to
    `FLAGS.image_size` x `FLAGS.image_size`, scaled by 1/255 and fed to the
    tensor `InputData/X:0`; the probabilities are read from
    `FullyConnected_1/Softmax:0`.

    NOTE(review): the training script attaches featurewise zero-center /
    std-norm preprocessing to the input layer; whether that is also applied
    inside the exported graph is not visible here - confirm that training
    and inference inputs are normalised the same way.

    Arguments:
        image: `str`. Path of the image file to classify.

    Returns:
        list with one array per input row, holding the indices of the
        3 highest-probability classes, best first.
    """
    print('inference')
    temp_image = Image.open(image).convert('L')
    # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS.
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
    temp_image = np.asarray(temp_image) / 255.0
    # Reshape with the same size the image was resized to (instead of a
    # hard-coded 32) so that the two can never disagree.
    temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open("frozen_model.pb", "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # The file is only needed for parsing; import the graph after closing it.
        tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            # Kept from the original; presumably a no-op for a fully frozen
            # graph (no variables) - TODO confirm.
            init = tf.global_variables_initializer()
            sess.run(init)
            op = sess.graph.get_tensor_by_name("FullyConnected_1/Softmax:0")
            input_tensor = sess.graph.get_tensor_by_name('InputData/X:0')
            probs = sess.run(op, feed_dict={input_tensor: temp_image})
            print(probs)

            # argsort of the negated probabilities sorts descending; keep top 3.
            return [np.argsort(-word)[:3] for word in probs]
 
 
def main(_):
    """ Run `inference` on one hard-coded sample image and log the outcome. """
    # Alternative sample: '../data/00010/17724.png'
    sample_path = './data/test/00098/104405.png'
    top_predictions = inference(sample_path)
    message = 'the result info label {0} predict index {1}'.format(98, top_predictions)
    logger.info(message)

 一般来说,TensorFlow中输入节点的名字默认为InputData/X,但这只是op的名字;如果要取tensor,需要在后面加上输出序号0,也就是:InputData/X:0

同理,FullyConnected_1/Softmax:0。

最后预测效果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
[[  8.42533936e-08   1.60850794e-11   2.60133332e-10   2.42555542e-14
    4.96124599e-08   4.45251297e-15   3.98175590e-11   1.64476592e-11
    7.03968351e-13   5.42319011e-12   8.55469237e-11   4.91866422e-13
    1.77282828e-07   4.05237593e-10   3.13049003e-10   1.34780919e-11
    2.05803235e-06   2.87827305e-07   1.47789994e-12   2.53391891e-11
    3.77086790e-13   2.02639586e-10   9.03167027e-13   3.96698889e-11
    1.30850096e-11   5.71980917e-12   3.03487374e-11   2.04132298e-14
    6.25303683e-13   1.46122332e-07   2.17450633e-07   1.69623715e-09
    6.80857757e-12   2.52643609e-13   6.56771096e-11   8.55152287e-16
    1.34496514e-09   1.22644633e-06   1.12011307e-07   7.93476283e-05
    8.24334611e-12   4.77531155e-14   9.39397757e-13   2.38438267e-14
    2.11416329e-10   5.54395712e-08   2.30046147e-12   2.63584043e-10
    4.70621564e-16   5.14432724e-12   6.42602327e-09   1.62485829e-13
    7.39078274e-08   3.19146315e-12   5.25887156e-09   1.35877786e-13
    1.39127886e-13   2.11998293e-13   9.09501097e-09   9.46486750e-07
    2.47498733e-09   2.74523763e-12   1.02716433e-14   1.02069058e-17
    3.09356682e-16   1.51022904e-15   9.34333665e-13   2.62195051e-14
    3.38079781e-16   7.43019903e-13   1.92409091e-13   3.86611994e-13
    2.61276265e-12   1.07969211e-09   1.30814548e-09   2.44038188e-14
    9.79275905e-13   1.41007803e-10   6.15137758e-12   2.08893070e-10
    1.34751668e-14   2.76824767e-15   7.84100464e-16   7.70873335e-15
    5.45704757e-12   3.69386271e-18   2.06012223e-13   1.62567273e-14
    1.54544960e-03   2.05292008e-06   1.31726174e-09   7.04993663e-09
    4.11338266e-03   3.19344110e-07   3.96519717e-05   2.26919351e-12
    2.39114349e-12   2.35558744e-07   9.94213998e-01   1.10125060e-11]]
the result info label 98 predict index [array([98, 92, 88])]

 

 
 

 

posted @   bonelee  阅读(1330)  评论(4编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2017-04-25 mongodb AND查询遇到多个index时候可能会做交集——和复合索引不同
2017-04-25 美国诚实签经验(最全集合)
2017-04-25 美国诚实签经验贴汇总
点击右上角即可分享
微信分享提示