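I decided to write up my reading of the TensorFlow CIFAR-10 convolutional network code, because I have trouble sitting still long enough to read it. Writing forces me to think instead of skimming, and lets me pick up where I left off after an interruption.
Since this is mainly for myself, I'll follow my own train of thought; apologies if that is inconvenient for anyone else. The plan is to start from the command python cifar10_train.py and follow it through the whole training run. Later I may add
a description of the data flow (although "next time" usually means never; that seems to be the pattern).
Let's begin. First, the list of source files: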
------------------------cifar10.py
------------------------cifar10_eval.py
------------------------cifar10_input.py
------------------------cifar10_input_test.py
------------------------cifar10_train.py
------------------------cifar10_multi_gpu_train.py
The official tutorial says to run python cifar10_train.py, and training starts. So let's look at cifar10_train.py.
When that command runs, the first code executed is:
if __name__ == '__main__':
  tf.app.run()
tf.app.run() parses the command-line flags and then calls the main function:
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()
The first statement in main is:
cifar10.maybe_download_and_extract()
It calls a function in the cifar10 module, so let's look at that function in cifar10.py:
def maybe_download_and_extract():
  """Download and extract the tarball from Alex's website."""
  dest_directory = FLAGS.data_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
          float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
  if not os.path.exists(extracted_dir_path):
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)
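There is nothing special about this function and nothing hard to read (for me, that is; I'm not showing off, and experienced readers will know what I actually mean; the same applies below).
Its job: check whether the data is already there. If it is, fine; if not, download it and extract it.
The next lines of main check whether the FLAGS.train_dir directory exists: if it does, it is deleted and then recreated; if not, it is simply created.
This is where the run's logs go, so a fresh run clears out the old ones first. The definition is in cifar10_train.py: it is the ./cifar10_train directory, which holds the event logs and checkpoints.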
tf.app.flags.DEFINE_string('train_dir', './cifar10_train',
"""Directory where to write event logs """
"""and checkpoint.""")
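The delete-then-recreate behaviour of main can be tried outside TensorFlow. The following is only a sketch using the standard library in place of tf.gfile; the helper name reset_dir and the temporary directory are assumptions made for illustration, not part of the tutorial code.

```python
import os
import shutil
import tempfile


def reset_dir(path):
    """Delete `path` if it exists, then create it again, empty.

    Hypothetical stand-in for the tf.gfile.Exists / DeleteRecursively /
    MakeDirs sequence in main().
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


# Work in a throwaway directory instead of the real ./cifar10_train.
base = tempfile.mkdtemp()
train_dir = os.path.join(base, "cifar10_train")

# Simulate a previous run that left a log file behind.
os.makedirs(train_dir)
with open(os.path.join(train_dir, "old_events.log"), "w") as f:
    f.write("stale")

reset_dir(train_dir)
remaining = os.listdir(train_dir)  # files left after the reset
print(remaining)

shutil.rmtree(base)  # clean up the throwaway directory
```

One consequence worth noting: because the directory is wiped on every start, earlier checkpoints in train_dir are removed too, so each invocation of the script begins from scratch.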
Back in main, only the train() function is left:
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()  # The hard part: peeling the onion, data preparation
    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)  # The main event: peeling the onion, network construction
    # Calculate loss.
    loss = cifar10.loss(logits, labels)  # The main event: building the loss
    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)  # The main event: building the training op (why "building"? Because nothing actually runs until the session runs it)
    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""
      def begin(self):
        self._step = -1
        self._start_time = time.time()
      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.
      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time
          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)
          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
Only after pasting it did I realize how much there is. Keep going, one line at a time, and it will all get done.
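As a warm-up, the arithmetic inside _LoggerHook.after_run is easy to check by hand. This is just a sketch of the two formulas from the code above; the function name throughput and the sample numbers (10 steps between logs, 2 seconds elapsed) are assumptions chosen for illustration.

```python
def throughput(log_frequency, batch_size, duration):
    """Reproduce the two rates printed by _LoggerHook.

    log_frequency: number of steps between two log lines
    batch_size:    examples processed per step
    duration:      seconds elapsed since the previous log line
    """
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = duration / log_frequency
    return examples_per_sec, sec_per_batch


# Illustrative numbers: 10 steps of 128 images took 2 seconds in total.
eps, spb = throughput(log_frequency=10, batch_size=128, duration=2.0)
print("%.1f examples/sec; %.3f sec/batch" % (eps, spb))
```

The duration covers all the steps since the last log line, which is why it is divided by log_frequency rather than used directly as the per-batch time.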
tf.Graph(): the official documentation says:
A TensorFlow computation, represented as a dataflow graph. A Graph contains a set of tf.Operation objects, which represent units of computation; and tf.Tensor objects, which represent the units of data that flow between operations.
In other words, it is the whole graph (if that is unfamiliar, search for "TensorFlow graph"; same below): the collection of computations and the data that flows between them. A bit more from the docs:
A default Graph is always registered, and accessible by calling tf.get_default_graph. To add an operation to the default graph, simply call one of the functions that defines a new Operation:
c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()
It feels a bit like Qt's Graphics View framework.
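The "build first, run later" idea behind the graph can be imitated in a few lines of plain Python. This is a toy sketch only, not how TensorFlow is implemented; the names Node, constant, add, mul and run are all made up for the illustration.

```python
class Node:
    """A deferred operation: it stores what to compute, not the result."""

    def __init__(self, fn, *inputs):
        self.fn = fn          # function that produces this node's value
        self.inputs = inputs  # upstream nodes whose values fn consumes


def constant(value):
    return Node(lambda: value)


def add(a, b):
    return Node(lambda x, y: x + y, a, b)


def mul(a, b):
    return Node(lambda x, y: x * y, a, b)


def run(node):
    """Evaluate a node by evaluating its inputs first (the 'session' step)."""
    args = [run(i) for i in node.inputs]
    return node.fn(*args)


c = constant(4.0)
d = mul(add(c, constant(1.0)), c)  # only a description so far, nothing computed
result = run(d)                    # now (4.0 + 1.0) * 4.0 is evaluated
print(result)
```

Building d costs nothing; the arithmetic only happens inside run, which plays the role that the session's run call plays for the real graph.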
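Back to the code. with tf.Graph().as_default(): sets up a blank sheet of paper and makes it the default graph. Now we carry on and start drawing on it. Next line:
global_step = tf.contrib.framework.get_or_create_global_step()
The first thing we draw on the graph is global_step.
That still doesn't tell me much, so off to the documentation.
The name says it all:
Returns and create (if necessary) the global step tensor.
Args:
  graph: The graph in which to create the global step tensor. If missing, use default graph.
Returns:
  The global step tensor.
Now the question is: what is the global step tensor?
The documentation found online says:
global_step: A scalar int32 or int64 Tensor or a Python number. Global step to use for the decay computation. Must not be negative.
It is used for things like learning-rate decay. Think of it as a global counter of training steps; that is a good enough understanding for now.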
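To make "used for the decay computation" concrete, here is a sketch of exponential learning-rate decay driven by a step counter, following the usual formula base_lr * decay_rate ^ (global_step / decay_steps). The function name and all the numbers below (0.1, 1000 steps) are illustrative assumptions, not values read from the tutorial code in this post.

```python
def decayed_learning_rate(base_lr, global_step, decay_steps, decay_rate,
                          staircase=True):
    """Exponential decay: base_lr * decay_rate ** (global_step / decay_steps).

    With staircase=True the exponent is an integer, so the rate drops in
    discrete jumps every `decay_steps` steps instead of continuously.
    """
    if global_step < 0:
        raise ValueError("global_step must not be negative")
    if staircase:
        exponent = global_step // decay_steps
    else:
        exponent = global_step / decay_steps
    return base_lr * decay_rate ** exponent


# Illustrative settings: start at 0.1 and divide by 10 every 1000 steps.
lr_start = decayed_learning_rate(0.1, 0, 1000, 0.1)
lr_before_drop = decayed_learning_rate(0.1, 999, 1000, 0.1)
lr_after_drop = decayed_learning_rate(0.1, 1000, 1000, 0.1)
print(lr_start, lr_before_drop, lr_after_drop)
```

The only state the schedule needs is the step counter, which is why a single global_step tensor is created once and shared with the training op.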
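Good. Back in train(), the next line is:
with tf.device('/cpu:0'):
This means the operations created inside the with block are placed on the CPU. The with block contains:
  images, labels = cifar10.distorted_inputs()  # just this one line, this one function
-------------------------------------------------------------------------------------------------------------------- peeling the onion: data preparation
Off to cifar10.py to read the code (remember to come back afterwards):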
def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:  # checks that the data_dir flag is set; no more explanation needed
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                  batch_size=FLAGS.batch_size)  # this line is the core
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)  # dealt with once the core line above has been explained
    labels = tf.cast(labels, tf.float16)
  return images, labels
Off to cifar10_input.py to read the code (remember to come back; this is where the onion-peeling really starts): cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=FLAGS.batch_size):
def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.
  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Why 6? xrange(1, 6) stops before 6, so i runs from 1 to 5. These are the five
  # training files data_batch_1.bin ... data_batch_5.bin under
  # /tmp/cifar10_data/cifar-10-batches-bin/ (the other .bin file in that
  # directory, test_batch.bin, is not read here).
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  # filenames is now a list of five paths. That is not enough: each one is
  # still checked for existence.
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)
  # All the data files exist, so now the real work starts.
  # Create a queue that produces the filenames to read.
  # What is tf.train.string_input_producer? See
  # https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer
  #   "Output strings (e.g. filenames) to a queue for an input pipeline"
  # It outputs a series of strings, such as filenames. Where to? Into a queue.
  # What for? For the input pipeline. How is it used? Read on.
  filename_queue = tf.train.string_input_producer(filenames)
  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)  # another key point, explained further down
  # tf.cast converts the dtype: the uint8 image becomes float32.
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)
  height = IMAGE_SIZE  # image size; cifar10_input.py defines IMAGE_SIZE = 24
  width = IMAGE_SIZE
  # Image processing for training the network. Note the many random
  # distortions applied to the image.
  # Randomly crop a [height, width] section of the image.
  # This crops a random region of the size we need, [height, width, 3]:
  # https://www.tensorflow.org/api_docs/python/tf/random_crop
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
  # Randomly flip the image horizontally (left to right).
  # https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right
  distorted_image = tf.image.random_flip_left_right(distorted_image)  # nothing hard here
  # Because these operations are not commutative, consider randomizing
  # the order their operation.
  # NOTE: since per_image_standardization zeros the mean and makes
  # the stddev unit, this likely has no effect see tensorflow#1458.
  # Adjust the brightness of images by a random factor.
  # https://www.tensorflow.org/api_docs/python/tf/image/random_brightness
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  # Adjust the contrast of an image by a random factor.
  # https://www.tensorflow.org/api_docs/python/tf/image/random_contrast
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)
  # Subtract off the mean and divide by the variance of the pixels.
  # (As the NOTE above says, the result has zero mean and unit stddev, so the
  # division is effectively by the standard deviation.)
  float_image = tf.image.per_image_standardization(distorted_image)
  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])
  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4  # minimum fraction of one epoch's examples to keep in the queue
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *  # cifar10_input defines NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
                           min_fraction_of_examples_in_queue)
  print('Filling queue with %d CIFAR images before starting to train. '
        'This will take a few minutes.' % min_queue_examples)
  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)
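Two small pieces of arithmetic in this function are easy to check in isolation: how many file names the list comprehension produces, and how large min_queue_examples ends up. The sketch below uses plain Python (range in place of xrange); the directory path is the one mentioned in the walkthrough and is only used to build strings, so nothing needs to exist on disk.

```python
import os

# Path taken from the walkthrough; only used to build strings here.
data_dir = "/tmp/cifar10_data/cifar-10-batches-bin"

# range(1, 6) yields 1, 2, 3, 4, 5: the upper bound 6 is excluded.
filenames = [os.path.join(data_dir, "data_batch_%d.bin" % i)
             for i in range(1, 6)]
for name in filenames:
    print(name)

# Queue size used to get a well-mixed shuffle: 40% of one training epoch.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                         min_fraction_of_examples_in_queue)
print(min_queue_examples)
```

So the input pipeline reads five training files and waits for 20000 preprocessed images to accumulate before it starts handing out batches, which explains the "This will take a few minutes" message.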
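Anything that can't be covered in a line or two gets explained afterwards. In the code above, read_input = read_cifar10(filename_queue) needs explaining: it is what reads the .bin data files.
Off to cifar10_input.py to read read_cifar10(filename_queue):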
def read_cifar10(filename_queue):
  """Reads and parses examples from CIFAR10 data files.
  Recommendation: if you want N-way read parallelism, call this function
  N times.  This will give you N independent Readers reading different
  files & positions within those files, which will give better mixing of
  examples.
  Args:
    filename_queue: A queue of strings with the filenames to read from.
  Returns: (look closely: the function defines a class inside itself and returns an instance of it)
    An object representing a single example, with the following fields:
      height: number of rows in the result (32)
      width: number of columns in the result (32)
      depth: number of color channels in the result (3)
      key: a scalar string Tensor describing the filename & record number
        for this example.
      label: an int32 Tensor with the label in the range 0..9.
      uint8image: a [height, width, depth] uint8 Tensor with the image data
  """
  class CIFAR10Record(object):
    pass  # pass is a do-nothing placeholder; here it gives an empty class whose instance is used as a simple container for the attributes set below
  result = CIFAR10Record()
  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # input format.
  label_bytes = 1  # 2 for CIFAR-100; the size of the label in bytes
  result.height = 32  # image height
  result.width = 32  # image width
  result.depth = 3  # the image is RGB, so the depth is 3
  image_bytes = result.height * result.width * result.depth  # bytes per image; note that this does not include the label
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  # (So the label comes first and the image bytes follow it.)
  record_bytes = label_bytes + image_bytes  # the size of one complete record
  # Read a record, getting filenames from the filename_queue.  No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  # FixedLengthRecordReader reads fixed-length records from a file; see
  # https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  # reader.read returns a tuple of Tensors (key, value):
  #   key: A string scalar Tensor.
  #   value: A string scalar Tensor.
  result.key, value = reader.read(filename_queue)
  # Convert from a string to a vector of uint8 that is record_bytes long.
  # The value that was read is converted from a string into a vector with tf.decode_raw:
  # https://www.tensorflow.org/api_docs/python/tf/decode_raw
  record_bytes = tf.decode_raw(value, tf.uint8)
  # decode_raw(bytes, out_type, little_endian=None, name=None)
  #   bytes: A Tensor of type string. All the elements must have the same length.
  #   out_type: A tf.DType from: tf.half, tf.float32, tf.float64, tf.int32,
  #     tf.uint8, tf.int16, tf.int8, tf.int64.
  #   little_endian: An optional bool. Defaults to True. Whether the input bytes
  #     are in little-endian order. Ignored for out_type values that are stored
  #     in a single byte like uint8.
  #   name: A name for the operation (optional).
  # The first bytes represent the label, which we convert from uint8->int32.
  # tf.strided_slice: https://www.tensorflow.org/api_docs/python/tf/strided_slice
  # tf.cast (dtype conversion): https://www.tensorflow.org/api_docs/python/tf/cast
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  # tf.reshape: https://www.tensorflow.org/api_docs/python/tf/reshape
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  # tf.transpose: https://www.tensorflow.org/api_docs/python/tf/transpose
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])
  return result
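The record layout that read_cifar10 decodes (one label byte, then the image stored one colour plane at a time) can be checked with plain bytes, no TensorFlow needed. This is a sketch on a synthetic record, not real CIFAR-10 data; the helper names parse_record and pixel are made up for the illustration.

```python
LABEL_BYTES = 1
HEIGHT, WIDTH, DEPTH = 32, 32, 3
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH        # 3072, label not included
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES    # 3073 bytes per record


def parse_record(record):
    """Split one fixed-length record into (label, pixel lookup function)."""
    if len(record) != RECORD_BYTES:
        raise ValueError("expected %d bytes, got %d" % (RECORD_BYTES, len(record)))
    label = record[0]                # the label comes first
    planes = record[LABEL_BYTES:]    # then [depth, height, width] image bytes

    def pixel(h, w):
        # Equivalent of the [depth, height, width] -> [height, width, depth]
        # transpose: gather the three channel values for one position.
        return tuple(planes[c * HEIGHT * WIDTH + h * WIDTH + w]
                     for c in range(DEPTH))

    return label, pixel


# Synthetic record: label 7, red plane all 10, green all 20, blue all 30.
plane = HEIGHT * WIDTH
record = bytes([7]) + bytes([10]) * plane + bytes([20]) * plane + bytes([30]) * plane
label, pixel = parse_record(record)
print(label, pixel(0, 0), pixel(31, 31))
```

Because the file stores whole colour planes one after another, the three values of a single pixel sit 1024 bytes apart, which is exactly why the code reshapes to [depth, height, width] first and transposes afterwards.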
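Note that only a single image comes back, wrapped in that object. That is the end of read_cifar10(filename_queue), so we go back to cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=FLAGS.batch_size):
one function peeled off. (Scroll back up and continue reading distorted_inputs from where we left it.)
So we are back in distorted_inputs after the detour through read_cifar10. Unfortunately text can't animate this; it just runs straight down the page, so the back-and-forth of reading code doesn't really come across. The best I can do is use different fonts.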
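Before moving on, the last preprocessing step in distorted_inputs, per_image_standardization, is worth a closer look. As documented, it computes (x - mean) / adjusted_stddev, where adjusted_stddev = max(stddev, 1 / sqrt(number of elements)) guards against division by zero on uniform images. The sketch below applies that formula to a flat list of pixel values in plain Python; it illustrates the formula and is not TensorFlow's implementation.

```python
import math


def per_image_standardize(pixels):
    """Scale a flat list of pixel values to zero mean and unit stddev.

    Uses adjusted_stddev = max(stddev, 1/sqrt(N)) so that a completely
    uniform image does not cause a division by zero.
    """
    n = len(pixels)
    mean = sum(pixels) / n
    variance = sum((p - mean) ** 2 for p in pixels) / n
    adjusted_stddev = max(math.sqrt(variance), 1.0 / math.sqrt(n))
    return [(p - mean) / adjusted_stddev for p in pixels]


# A tiny made-up "image" of four pixel values.
out = per_image_standardize([0.0, 50.0, 100.0, 150.0])
out_mean = sum(out) / len(out)
out_var = sum(o * o for o in out) / len(out)
print(out_mean, out_var)

# A uniform image: the stddev is 0, so the 1/sqrt(N) floor is used instead.
flat = per_image_standardize([5.0, 5.0, 5.0, 5.0])
print(flat)
```

This also shows why the NOTE in the code doubts that the random brightness and contrast steps matter much: a constant shift or a scaling about the mean is largely undone once the mean is subtracted and the result is rescaled.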
Now only the _generate_image_and_label_batch function below is left. On we go:
return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size,
shuffle=True)
One thing to add about the arguments passed in here: batch_size comes from a flag defined as tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  """Construct a queued batch of images and labels.
  Args:
    image: 3-D Tensor of [height, width, 3] of type.float32.
    label: 1-D Tensor of type.int32
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides of batches of examples.
    batch_size: Number of images per batch.
    shuffle: boolean indicating whether to use a shuffling queue. (shuffling: putting the examples in random order)
  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Create a queue that shuffles the examples, and then
  # read 'batch_size' images + labels from the example queue.
  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
  else:
    # tf.train.batch is an important function: https://www.tensorflow.org/api_docs/python/tf/train/batch
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)
  # Display the training images in the visualizer.
  tf.summary.image('images', images)
  return images, tf.reshape(label_batch, [batch_size])
That was quick: _generate_image_and_label_batch is done, and at this point the data is ready.
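The roles of capacity and min_after_dequeue in the shuffling queue are easier to see in a toy version. The following is a single-threaded sketch of a shuffle buffer in plain Python. It is an assumption-laden imitation: the real tf.train.shuffle_batch uses a queue filled by background threads, and here an incomplete final batch is simply dropped. The function name shuffle_batches is made up.

```python
import random


def shuffle_batches(examples, batch_size, min_after_dequeue, capacity, seed=0):
    """Yield shuffled batches drawn at random from a bounded buffer.

    The buffer is topped up to `capacity` before each draw, so while input
    remains, at least `min_after_dequeue` elements stay behind after a
    dequeue, which is what keeps the mixing good.
    """
    if capacity <= min_after_dequeue:
        raise ValueError("capacity must exceed min_after_dequeue")
    rng = random.Random(seed)
    source = iter(examples)
    buffer, batch = [], []
    while True:
        for item in source:              # top the buffer up to capacity
            buffer.append(item)
            if len(buffer) >= capacity:
                break
        if not buffer:                   # input and buffer both exhausted
            return
        index = rng.randrange(len(buffer))
        buffer[index], buffer[-1] = buffer[-1], buffer[index]
        batch.append(buffer.pop())       # dequeue one random element
        if len(batch) == batch_size:
            yield batch
            batch = []


# Small illustrative numbers, using the same capacity rule as the code above:
# capacity = min_after_dequeue + 3 * batch_size.
batches = list(shuffle_batches(range(20), batch_size=4,
                               min_after_dequeue=5, capacity=5 + 3 * 4))
print(batches)

# With the values from the walkthrough the queue capacity would be:
real_capacity = 20000 + 3 * 128
print(real_capacity)
```

The larger min_after_dequeue is, the more candidates each random draw has to choose from, at the cost of more memory and a longer wait before the first batch.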
With cifar10_input.distorted_inputs out of the way, we return to cifar10.distorted_inputs (in cifar10.py), another layer of the onion peeled, and find that this is what remains:
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)  # the line postponed earlier
    labels = tf.cast(labels, tf.float16)
  return images, labels
Not much to say here: when FLAGS.use_fp16 is set, both tensors are cast to float16, and then they are returned.
With cifar10.distorted_inputs done, we are back in train() in cifar10_train.py. Next up is:
-------------------------------------------------------------------------------------------------------------------- peeling the onion: network construction
There is too much for one post, and that could cause problems, so I'll split it across several posts.