CNN图像识别的经典模型简述

0. 雁字无多

    我这两天要是再不学习一下之后可能时间就很少了,期末作业是做了一些了。但是,今天接到一个大任务,今年实验室招标和项目我负责。就在我写这段文字的时候,我老板进来了。正好,我就继续写吧。可是又说了半天招标的事情,说XX所好赚钱好赚钱,流程怎么怎么跑,要去招标局和公司做啥做啥。

    这周换了一本书开始看,觉得没有上一本写的好。现在感觉这些书对我帮助不是很大了,书其实没有一些CSDN上的精品博客写的好。今天想写一下CNN的经典网络吧,做图像识别分类的。要具体的讲,自2012年开始,基于CNN用于打ILSVRC图像识别大赛的几大经典网络分别是:AlexNet,VGGNets,GoogLeNet&Inception和ResNets。这几大具有统治力的模型也是每年一更新不断刷新识别率与准确率。具体每个网络的结构、异同点不妨CSDN去看看,可以根据实际需求选择一个。

    下面我也就个人的理解对这几大网络做一个背景的简要描述。

    AlexNet 在2012年被提出是被认为开创了深度学习的时代,大数据、GPU、ReLu函数(加快收敛速度)和dropout(防止过拟合)等技术也都是为它的出现奠定了基础。它包含5层卷积层和3层全连接层如图1,这是网上的一张经典图形。因为ILSVRC是个千分类问题,所以全连接层最后的softmax输出为1000维向量。不过现在实际应用中现在很难看到AlexNet,毕竟它只是作为LeNet网络的一个历史突破性版本,它具有历史里程意义,但技术价值已经不如后面的新贵们了。

图1 AlexNet

    VGGNet 是在AlexNet基础之上提出的,网络结构相似。不同点在于,VGGNet是在每一层卷积层上连续卷积2~4次,结构上与AlexNet相比并没有做什么改变,在此不再赘述。VGGNet对深度学习最大的贡献莫过于:不考虑其他因素(计算存储等),CNN网络的深度不断加深(增加卷积次数,通俗讲就是使计算复杂化)可以提升准确率。这也成为了现在做机器学习提升效果的核心方法之一:加深网络、数据增强模型融合。另外调参也很重要。虽然VGGNet参数很多,但由于设计上的很多细节,它能够很快收敛,具体原因我需要提一下吗....好吧,主要是卷积核大小比较小。最后还需要说一点,无脑的加深网络深度是不能得到效果的线性提升,这是因为参数变多之后误差函数的梯度弥散问题,导致网络在训练时无法收敛(无论从顺着哪个参数的梯度进行下降都无法使误差减小,因为梯度求导几乎为0)。比如VGGNet在超过20层后就效果就会下降。这个梯度弥散问题在后续的ResNets得到了一定程度的解决。

    GoogLeNet 与VGGNet在相比,在内存和计算消耗方面有非常大的优势。AlexNet增多卷积次数,不可避免使其增多了权值参数个数,它共有6000万个权值参数,是AlexNet的三倍以上;而GoogleLeNet只有500万个参数,所以在内存较小的移动端GoogLeNet有着更广泛的应用。当时提出GoogLeNet,就是因为Google的学者们虽然同意像AlexNet那样加深网络深度可以提高识别准确率,但如何更加有效的加深网络深度,使其能够在保证准确率提升的前提下尽可能的“轻量”,所以开启了名为Inception的项目工程。GoogLeNet就是InceptionV1版本,现在已经发布到InceptionV4版本——结合ResNets的GoogLeNet。具体的发展与区别可以百度了解一下。结构图是实在很复杂如图2。

 

图2 GoogLeNet

    ResNets 是2016年微软亚研院的何凯明博士及其团队提出的深度残差网络。其最大的亮点就是通过设计的残差网络结构,避免了随着随着网络层数加深而产生的梯度消失或梯度爆炸的问题(该问题被称为深度网络的退化问题,degradation problem),不但能使深度神经网络的收敛速度更快、精度更高,而且让加深网络深度来提高网络效果成为可能。在学习的时候,很多博客和教材会把ResNets的网络结构同VGG进行对比学习,如图3。

 

图3 ReNets

1. 写得相思几许

    这几大网络github上有很好的项目工程,python3.6+tensorflow1.7亲测可用。贴一个链接吧:https://github.com/MachineLP/train_arch

    这一小节还是得弄一点干货吧哈哈哈,那就上一个我自己调通了把数据集下载这些都合成在一个py文件的GoogLeNet吧。InceptionV3,代码参考上述链接。

    

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import re
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf


FLAGS = None
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'

# NodeLookup类负责将分类器输出的类别编号与人类可读的标签名称对应起来
class NodeLookup(object):
  def __init__(self,
               label_lookup_path=None,
               uid_lookup_path=None):
    if not label_lookup_path:
      label_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
    if not uid_lookup_path:
      uid_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
    self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

# 为每一个softmax节点读取人类可读的类标英文名字
  def load(self, label_lookup_path, uid_lookup_path):
    if not tf.gfile.Exists(uid_lookup_path):
      tf.logging.fatal('File does not exist %s', uid_lookup_path)
    if not tf.gfile.Exists(label_lookup_path):
      tf.logging.fatal('File does not exist %s', label_lookup_path)

    # Loads mapping from string UID to human-readable string
    proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
      parsed_items = p.findall(line)
      uid = parsed_items[0]
      human_string = parsed_items[2]
      uid_to_human[uid] = human_string

    # Loads mapping from string UID to integer node ID.
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
    for line in proto_as_ascii:
      if line.startswith('  target_class:'):
        target_class = int(line.split(': ')[1])
      if line.startswith('  target_class_string:'):
        target_class_string = line.split(': ')[1]
        node_id_to_uid[target_class] = target_class_string[1:-2]

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
      if val not in uid_to_human:
        tf.logging.fatal('Failed to locate: %s', val)
      name = uid_to_human[val]
      node_id_to_name[key] = name
    return node_id_to_name


  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


# 从protocol buffer文件中反序列化出inception-v3模型及参数
def create_graph():
  # Creates graph from saved graph_def.pb.
  with tf.gfile.FastGFile(os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

   
# 使用v3模型对image图片进行分类,并输出top5置信度的类别预测
def run_inference_on_image(image):
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  image_data = tf.gfile.FastGFile(image, 'rb').read()

  # Creates graph from saved GraphDef.
  create_graph()

  with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.

    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup()
    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

# 下载模型存档并解压
def maybe_download_and_extract():
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def main(_):
  maybe_download_and_extract()
  image = (FLAGS.image_file if FLAGS.image_file else
           os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
  run_inference_on_image(image)


if __name__ == '__main__':

  parser = argparse.ArgumentParser()
  # classify_image_graph_def.pb:
  #   Binary representation of the GraphDef protocol buffer.
  # imagenet_synset_to_human_label_map.txt:
  #   Map from synset ID to a human readable string.
  # imagenet_2012_challenge_label_map_proto.pbtxt:
  #   Text representation of a protocol buffer mapping a label to synset ID.
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/imagenet',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """

  )

  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )

  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

    运行效果如下:

图4 运行结果

    简要分析下吧。程序使用InceptionV3的存档模型对一张可爱的滚滚(熊猫)图片进行识别。结果显示该图片属于panda的概率为89.1%,Indir的概率为0.8%,lesser panda的概率为0.3%等等。

    哈哈哈哈,当把照片换成我家小天使时。小家伙你也太鼠头鼠脑的了吧哈哈哈:

    

 

posted on 2018-06-07 16:42  Allen.c  阅读(695)  评论(0编辑  收藏  举报

导航