TensorFlow: Neural Networks

The two-hidden-layer MLP below reaches roughly 94-95% precision @ 1 on MNIST. Evaluation output at the end of training:

Training Data Eval:
  Num examples: 55000  Num correct: 52015  Precision @ 1: 0.9457
Validation Data Eval:
  Num examples: 5000  Num correct: 4740  Precision @ 1: 0.9480
Test Data Eval:
  Num examples: 10000  Num correct: 9456  Precision @ 1: 0.9456

The complete script:
from __future__ import division, print_function  # float division and print() under Python 2

import math

import tensorflow as tf
import input_data  # MNIST loader from the TensorFlow MNIST tutorial

NUM_CLASSES = 10                        # digits 0-9
IMAGE_SIZE = 28                         # MNIST images are 28x28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE  # flattened input size: 784

# Command-line flags (tf.app.flags, TensorFlow 0.x-era API).
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 10000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
                     'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                     'for unit testing.')

def inference(images, hidden1_units, hidden2_units):
  """Builds the forward pass: two ReLU hidden layers plus a linear output."""
  with tf.name_scope('hidden1'):
    # Truncated-normal init with stddev = 1/sqrt(fan_in).
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  with tf.name_scope('softmax_linear'):
    # Linear layer only; the softmax is folded into the loss below.
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases
  return logits

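A quick note on the initializer used in all three layers: the weights are drawn from a truncated normal with stddev = 1/sqrt(fan_in), which keeps the scale of each layer's pre-activations roughly independent of its input width. The concrete values for the default layer sizes (a plain-Python check, not part of the script):

import math
for fan_in in (784, 128, 32):  # IMAGE_PIXELS, hidden1, hidden2 defaults
    print(fan_in, 1.0 / math.sqrt(float(fan_in)))
# prints roughly 0.0357, 0.0884, 0.1768
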
def loss(logits, labels):
  """Mean cross-entropy between the logits and the integer labels."""
  labels = tf.to_int64(labels)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=labels, name='xentropy')
  loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
  return loss

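For intuition: sparse_softmax_cross_entropy_with_logits fuses the softmax and the cross-entropy, computing -log(softmax(logits)[label]) per example; loss() then averages over the batch. A minimal NumPy equivalent for a single example (illustrative only, not part of the script):

import numpy as np

def sparse_xent(logits, label):
    # Numerically stable log-softmax, then pick out the true class.
    shifted = logits - logits.max()
    log_softmax = shifted - np.log(np.exp(shifted).sum())
    return -log_softmax[label]

print(sparse_xent(np.array([2.0, 1.0, 0.1]), 0))  # ~0.417
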
def training(loss, learning_rate):
  """Adds a loss summary and a plain SGD train op."""
  tf.scalar_summary(loss.op.name, loss)  # TF 0.x summary API
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  global_step = tf.Variable(0, name='global_step', trainable=False)
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op

def evaluation(logits, labels):
  """Number of examples whose true class has the highest logit."""
  correct = tf.nn.in_top_k(logits, labels, 1)
  return tf.reduce_sum(tf.cast(correct, tf.int32))

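tf.nn.in_top_k(logits, labels, 1) yields one boolean per example: whether the true class is among the top-1 logits. Casting and summing therefore counts the correct predictions in a batch. Roughly the same check in NumPy (ignoring in_top_k's tie handling; illustrative only):

import numpy as np

logits = np.array([[0.1, 0.9],
                   [0.8, 0.2]])
labels = np.array([1, 1])
print((logits.argmax(axis=1) == labels).sum())  # 1: only the first example is correct
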
def placeholder_inputs(batch_size):
  """Placeholders for a batch of flattened images and their labels."""
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
  return images_placeholder, labels_placeholder

def fill_feed_dict(data_set, images_pl, labels_pl):
  """Pulls the next batch and maps it onto the two placeholders."""
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one full epoch over data_set and prints precision @ 1."""
  true_count = 0
  # Partial batches are dropped, so num_examples may be slightly
  # smaller than data_set.num_examples.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in range(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = true_count / num_examples  # true division via the __future__ import
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))

def run_training():
  """Trains the MLP on MNIST and periodically evaluates all three splits."""
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
  print(FLAGS.train_dir, FLAGS.fake_data)  # echo the data directory and fake_data flag
  with tf.Graph().as_default():
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)
    # Build the graph: forward pass, loss, train op, and eval op.
    logits = inference(images_placeholder,
                       FLAGS.hidden1,
                       FLAGS.hidden2)
    loss_mnist = loss(logits, labels_placeholder)
    train_op = training(loss_mnist, FLAGS.learning_rate)
    eval_correct = evaluation(logits, labels_placeholder)
    summary = tf.merge_all_summaries()    # TF 0.x summary API
    init = tf.initialize_all_variables()  # TF 0.x initializer API
    sess = tf.Session()
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
    sess.run(init)
    for step in range(FLAGS.max_steps):
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)
      _, loss_value = sess.run([train_op, loss_mnist],
                               feed_dict=feed_dict)

      if step % 100 == 0:
        # Log the loss and write summaries for TensorBoard.
        print('Step %d: loss = %.2f' % (step, loss_value))
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        # Every 1000 steps (and at the end), evaluate all three splits.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)

if __name__ == '__main__':
  run_training()
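Assuming the listing is saved as mnist_mlp.py (a filename chosen here for illustration) alongside the tutorial's input_data.py, the tf.app.flags defaults can be overridden from the command line, for example:

python mnist_mlp.py --max_steps 5000 --hidden1 256 --train_dir data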