TensorFlow-Slim 简介+Demo
GitHub 介绍:https://github.com/tensorflow/tensorflow/tree/r1.15/tensorflow/contrib/slim
基于 slim 实现的 YOLOv3(测试可用):https://github.com/mystic123/tensorflow-yolo-v3
简介
- TF-Slim 是一个轻量级的 TensorFlow 库(位于 tf.contrib.slim;TensorFlow 2.x 已移除 tf.contrib,可改用独立的 tf-slim 包)。
- 它可以让复杂模型的定义、训练和评估(测试)更简单。
- 它的组件可以与 TensorFlow 的其他库(如 tf.contrib.learn)混合使用。
- 它通过消除样板代码(boilerplate code),让用户能够更紧凑地定义模型。
Demo
# TF-Slim demo: train a small CNN on the first 10000 MNIST training images.
# Requires TensorFlow 1.x (tf.contrib was removed in TensorFlow 2.x) and Keras
# for the dataset loader.
import math

import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import layers as layers_lib
from tensorflow.contrib import layers
import tensorflow.contrib.slim as slim
from keras.datasets import mnist

print("Hello slim.")

pixel_depth = 256
learning_rate = 0.01
checkpoint_dir = "./ckpts/"
# NOTE(review): log_dir is never used; slim.learning.train below writes both
# checkpoints and summaries to checkpoint_dir.
log_dir = "./logs/"
# NOTE(review): batch_size is never used; the training graph below consumes
# the whole 10000-sample subset as one batch on every step.
batch_size = 1000

# Get the data, mnist.npz is in ~/.keras/datasets/mnist.npz
print("Loading the MNIST data in ~/.keras/datasets/mnist.npz")
(train_data, train_labels), (test_data, test_labels) = mnist.load_data()

# Add a trailing channel axis (NHWC with 1 channel) and cast images to float32
# and labels to int64, the dtypes the TF ops below consume.
train_data = train_data.reshape(-1, 28, 28, 1).astype(np.float32)
train_labels = train_labels.reshape(-1).astype(np.int64)
test_data = test_data.reshape(-1, 28, 28, 1).astype(np.float32)
test_labels = test_labels.reshape(-1).astype(np.int64)

# Rescale pixels to roughly [-1, 1): with 8-bit pixels in [0, 255] the result
# spans [-1.0, 0.992].
train_data = 2.0 * train_data / pixel_depth - 1.0
test_data = 2.0 * test_data / pixel_depth - 1.0

# Train on the first 10000 samples only.
train_data = train_data[0:10000]
train_labels = train_labels[0:10000]
print("train data shape:", train_data.shape)
print("test data shape:", test_data.shape)


# Layer layout loosely follows slim.nets.vgg.vgg_16.
def MyModel(inputs, num_classes=10, is_training=True, dropout_keep_prob=0.5,
            spatial_squeeze=False, scope='MyModel'):
    """Build a two-conv / two-FC classifier and return unscaled logits.

    Args:
        inputs: NHWC image batch; this demo feeds shape [N, 28, 28, 1].
        num_classes: number of output classes (width of the logits).
        is_training, dropout_keep_prob, spatial_squeeze: accepted for
            signature compatibility with slim's vgg_16; currently unused.
        scope: variable scope that wraps every layer.

    Returns:
        A [N, num_classes] logits tensor (no softmax applied).
    """
    with tf.variable_scope(scope):
        # Defaults for every conv / fully-connected layer below: ReLU,
        # truncated-normal(0, 0.01) init and an L2 weight regularizer.
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            net = slim.convolution2d(inputs, 8, [3, 3], 1, padding='SAME', scope='conv1')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool1')
            net = slim.convolution2d(net, 8, [5, 5], 1, padding='SAME', scope='conv2')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool2')
            net = slim.flatten(net, scope='flatten1')
            # Hidden FC layer keeps the arg_scope ReLU: with no non-linearity
            # here, fc1 and fc2 would collapse into a single linear map.
            net = slim.fully_connected(net, num_classes * num_classes, scope='fc1')
            # Output layer: raw logits for the softmax cross-entropy loss.
            net = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc2')
    return net


def train_data_batch(batch_size):
    """Yield one batch of ``batch_size`` training images, wrapping around.

    The read position is stored on the function object itself
    (``train_data_batch.train_index``). Once a generator from this function
    has been advanced, the next call starts where it stopped. Each generator
    yields exactly one batch, so it is presumably meant to be consumed as
    ``next(train_data_batch(n))``.

    NOTE(review): only images are yielded (no labels), and nothing in this
    script calls the helper.
    """
    if not hasattr(train_data_batch, 'train_index'):
        train_data_batch.train_index = 0
    data_size = train_labels.shape[0]
    start = train_data_batch.train_index
    # Indices wrap modulo the dataset size, so a batch may straddle the end.
    idx = np.arange(start, start + batch_size) % data_size
    train_data_batch.train_index = (start + batch_size) % data_size
    yield train_data[idx]


# Build the graph. The NumPy training subset is converted to a graph constant
# and used as a single full batch on every step.
logits = MyModel(train_data)
# Keyword arguments make the order explicit: slim.losses (tf.contrib.losses)
# takes (logits, labels), the reverse of tf.losses.
loss = slim.losses.sparse_softmax_cross_entropy(logits=logits, labels=train_labels)
# NOTE(review): MyModel attaches an L2 weight regularizer, but its losses are
# excluded here; pass add_regularization_losses=True to actually apply it.
total_loss = slim.losses.get_total_loss(add_regularization_losses=False)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Run 100 training steps; checkpoint_dir is used as slim's logdir.
slim.learning.train(train_op, checkpoint_dir, number_of_steps=100,
                    save_summaries_secs=5, save_interval_secs=10)
print("See you, slim.")