TensorFlow入门 - 使用TensorFlow甄别图片中的时尚单品

##使用TensorFlow甄别图片中的时尚单品

MNIST数据集是一个经典的机器学习数据集,该数据集由像素大小28*28的手写数字图片构成,每一个图片都由该图片对应的数字标记,经常用于实现用机器学习模型识别其中的数字来完成对机器学习算法的性能对标。

本例并没有直接使用MNIST数据集,为了使我们的实现更有趣一点,我们采用了Zalando发布的fashion-mnist数据集。该数据集与MNIST格式一致,但数字被换成了10个种类的挎包、服饰、鞋子。

以下是Jupyter Notebook中的整个实现过程:

在tensorflow虚拟环境中启动jupyter notebook

steve@steve-Lenovo-V2000:~$ source activate tensorflow
(tensorflow) steve@steve-Lenovo-V2000:~$ jupyter notebook
In[1]	
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
tf.logging.set_verbosity(tf.logging.INFO)

#pixels即特征名称 shape特征的个数 28*28 = 784
feature_columns = [tf.feature_column.numeric_column('pixels', 							  shape=784)]

In[2]		
#线性分类器
    classifier = tf.estimator.LinearClassifier(
    				 feature_columns = feature_columns,
    			 	 n_classes = 10,
    			  	 model_dir = "tmp/fashion_mnist/linear"
     	 )

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'tmp/fashion_mnist/linear', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100}

In[3]	
#输入函数
def input_fn(data, batch_size, num_epochs, shuffle):
    		return tf.estimator.inputs.numpy_input_fn(
            x = {'pixels': data.images},
            y = data.labels.astype(np.int64),
            batch_size = batch_size,
            num_epochs = num_epochs,
            shuffle = shuffle
           )

In[4]	
#从下载数据集的路径中读取数据保存到对象DATA_SETS中
DATA_SETS = input_data.read_data_sets("/home/steve/fashion_mnist")

Extracting /home/steve/fashion_mnist/train-images-idx3-ubyte.gz
Extracting /home/steve/fashion_mnist/train-labels-idx1-ubyte.gz
Extracting /home/steve/fashion_mnist/t10k-images-idx3-ubyte.gz
Extracting /home/steve/fashion_mnist/t10k-labels-idx1-ubyte.gz

In[5]	
#训练
classifier.train(input_fn = input_fn(DATA_SETS.train, batch_size = 100, 					num_epochs = 3, shuffle = True))

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from tmp/fashion_mnist/linear/model.ckpt-16500
INFO:tensorflow:Saving checkpoints for 16501 into tmp/fashion_mnist/linear/model.ckpt.
INFO:tensorflow:loss = 38.0863, step = 16501
INFO:tensorflow:global_step/sec: 344.156
INFO:tensorflow:loss = 36.0887, step = 16601 (0.291 sec)
INFO:tensorflow:global_step/sec: 322.998
INFO:tensorflow:loss = 42.7409, step = 16701 (0.311 sec)
INFO:tensorflow:global_step/sec: 281.274
INFO:tensorflow:loss = 41.0631, step = 16801 (0.354 sec)
INFO:tensorflow:global_step/sec: 477.183
INFO:tensorflow:loss = 34.4951, step = 16901 (0.210 sec)
INFO:tensorflow:global_step/sec: 518.435
INFO:tensorflow:loss = 29.0231, step = 17001 (0.193 sec)
INFO:tensorflow:global_step/sec: 488.121
INFO:tensorflow:loss = 25.1947, step = 17101 (0.205 sec)
INFO:tensorflow:global_step/sec: 508.804
INFO:tensorflow:loss = 29.5358, step = 17201 (0.197 sec)
INFO:tensorflow:global_step/sec: 506.08
INFO:tensorflow:loss = 41.3665, step = 17301 (0.198 sec)
INFO:tensorflow:global_step/sec: 383.223
INFO:tensorflow:loss = 37.4047, step = 17401 (0.261 sec)
INFO:tensorflow:global_step/sec: 408.678
INFO:tensorflow:loss = 31.8637, step = 17501 (0.245 sec)
INFO:tensorflow:global_step/sec: 314.574
INFO:tensorflow:loss = 36.5287, step = 17601 (0.318 sec)
INFO:tensorflow:global_step/sec: 431.968
INFO:tensorflow:loss = 27.3364, step = 17701 (0.231 sec)
INFO:tensorflow:global_step/sec: 423.543
INFO:tensorflow:loss = 43.4093, step = 17801 (0.237 sec)
INFO:tensorflow:global_step/sec: 389.708
INFO:tensorflow:loss = 36.8287, step = 17901 (0.256 sec)
INFO:tensorflow:global_step/sec: 487.954
INFO:tensorflow:loss = 40.7856, step = 18001 (0.205 sec)
INFO:tensorflow:global_step/sec: 513.674
INFO:tensorflow:loss = 28.5917, step = 18101 (0.195 sec)
INFO:tensorflow:Saving checkpoints for 18150 into tmp/fashion_mnist/linear/model.ckpt.
INFO:tensorflow:Loss for final step: 41.4883.

Out[1]	<tensorflow.python.estimator.canned.linear.LinearClassifier at 				0x7f25ce5f2b00>

In[6]	
#计算准确度

accuracy_score = classifier.evaluate(input_fn = input_fn(DATA_SETS.test, 					 					batch_size = 100, num_epochs = 1, 											  shuffle = False))['accuracy']

INFO:tensorflow:Starting evaluation at 2018-03-07-08:00:07
INFO:tensorflow:Restoring parameters from tmp/fashion_mnist/linear/model.ckpt-18150
INFO:tensorflow:Finished evaluation at 2018-03-07-08:00:08
INFO:tensorflow:Saving dict for global step 18150: accuracy = 0.8446, average_loss = 0.450767, global_step = 18150, loss = 45.0767

In[7]	
#深度分类器 使用了深度神经网络
deep_classifier = tf.estimator.DNNClassifier(
    						feature_columns = feature_columns,
    						n_classes = 10,
    						hidden_units =[100, 75, 50],
    						model_dir = "tmp/fashion_mnist/deep"
                          )

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'tmp/fashion_mnist/deep', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100}


In[8]	
#训练
deep_classifier.train(input_fn = input_fn(DATA_SETS.train, 
batch_size = 100, num_epochs = 2, shuffle = True))
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from tmp/fashion_mnist/deep/model.ckpt-6600
INFO:tensorflow:Saving checkpoints for 6601 into tmp/fashion_mnist/deep/model.ckpt.
INFO:tensorflow:loss = 37.3415, step = 6601
INFO:tensorflow:global_step/sec: 185.682
INFO:tensorflow:loss = 25.2046, step = 6701 (0.541 sec)
INFO:tensorflow:global_step/sec: 182.251
INFO:tensorflow:loss = 37.8857, step = 6801 (0.547 sec)
INFO:tensorflow:global_step/sec: 243.679
INFO:tensorflow:loss = 20.0217, step = 6901 (0.412 sec)
INFO:tensorflow:global_step/sec: 191.407
INFO:tensorflow:loss = 18.9174, step = 7001 (0.521 sec)
INFO:tensorflow:global_step/sec: 210.536
INFO:tensorflow:loss = 26.6609, step = 7101 (0.475 sec)
INFO:tensorflow:global_step/sec: 225.08
INFO:tensorflow:loss = 29.2771, step = 7201 (0.444 sec)
INFO:tensorflow:global_step/sec: 238.886
INFO:tensorflow:loss = 26.5615, step = 7301 (0.419 sec)
INFO:tensorflow:global_step/sec: 209.95
INFO:tensorflow:loss = 24.2403, step = 7401 (0.477 sec)
INFO:tensorflow:global_step/sec: 214.734
INFO:tensorflow:loss = 20.5827, step = 7501 (0.465 sec)
INFO:tensorflow:global_step/sec: 223.289
INFO:tensorflow:loss = 36.1836, step = 7601 (0.448 sec)
INFO:tensorflow:Saving checkpoints for 7700 into tmp/fashion_mnist/deep/model.ckpt.
INFO:tensorflow:Loss for final step: 25.4458.

Out[2]	<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x7f25ca79df28>

In[9]	
#计算准确度
accuracy_score = deep_classifier.evaluate(input_fn = 						 			input_fn(DATA_SETS.test, batch_size = 100, num_epochs = 1, 											  shuffle = False))['accuracy']

INFO:tensorflow:Starting evaluation at 2018-03-07-08:00:43
INFO:tensorflow:Restoring parameters from tmp/fashion_mnist/deep/model.ckpt-7700
INFO:tensorflow:Finished evaluation at 2018-03-07-08:00:43
INFO:tensorflow:Saving dict for global step 7700: accuracy = 0.8743, average_loss = 0.355154, global_step = 7700, loss = 35.5154


In[10]	
#选取4000到40045张图片进行一次实际的预测
#预测输入函数
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        	               x = {'pixels': DATA_SETS.test.images[4000:4005]},
            	           batch_size = 1,
                	       num_epochs = 1,
                    	   shuffle = False)
In[11]	predictions = deep_classifier.predict(input_fn = predict_input_fn)
In[12]	for p in predictions:
    		print('Prediction: {} with probabilities {}\n'.format(p['classes'], 														   p['probabilities']))

In[13]	print('Expect correct answer values is: {}'.format
(DATA_SETS.test.labels[4000:4005]))

INFO:tensorflow:Restoring parameters from tmp/fashion_mnist/deep/model.ckpt-7700
Prediction: [b'0'] with probabilities [  8.85963202e-01   2.56781423e-05   4.29330952e-03   3.06227105e-03
   5.28713455e-04   2.37491005e-09   1.06063619e-01   2.90737194e-08
   6.30379072e-05   1.69134964e-07]

Prediction: [b'0'] with probabilities [  9.96059775e-01   3.01950354e-09   1.58900730e-05   1.49504587e-04
   3.14754004e-07   1.30554459e-16   3.77456401e-03   1.15441985e-13
   3.15555271e-09   2.77197384e-12]

Prediction: [b'9'] with probabilities [  2.05979784e-08   4.87241074e-08   6.31761395e-07   2.98880963e-07
   5.09397609e-08   9.14796401e-05   4.43378804e-06   3.86087340e-03
   8.52311146e-08   9.96042013e-01]

Prediction: [b'8'] with probabilities [  8.03876731e-07   2.79696058e-11   9.39605727e-09   3.52724818e-08
   4.37717148e-07   4.22516162e-13   3.26979011e-06   4.28372060e-11
   9.99995470e-01   5.50698827e-11]

Prediction: [b'5'] with probabilities [  1.81485848e-05   9.46296313e-06   1.18957887e-05   5.37845017e-06
   1.25248152e-06   9.98875082e-01   6.31277635e-06   6.29263523e-04
   2.42290087e-04   2.00855589e-04]
Expect correct answer values is: [0 0 9 8 5]



In[14]	
#使用matplotlib.pyplot绘制图片
import matplotlib.pyplot as plt
#jupyter notebook魔法函数 其功能是将matplotlib绘制的图片直接嵌入到	#notebook里
%matplotlib inline

for i in range(4000, 4005):
 		   sample = np.reshape(DATA_SETS.test.images[i], (28, 28))
  		   plt.figure()
 		   plt.title('Label: {}'.format(DATA_SETS.test.labels[i]))
 		   plt.imshow(sample, 'gray')

这里写图片描述

以上5张图片是使用深度分类器实际进行的5次预测,你可以看到5件衣服以及顶部使用数字标明的衣服种类。实际标签依次为0、0、9、8、5,我们的预测结果为0、0、9、8、5。就实际进行的一次预测的结果来说,预测的4000到4004一共5张图片的衣服类型与它们的标签完全一致,100%的成功了。但是针对test数据集进行整体预测的结果进行评估,线性分类器的准确度为84.46%,而深度分类器的准确度为87.43%,很明显深度分类器的准确度高于线性分类器。 事实上,深度分类器的hidden_units参数对预测结果的准确度有着莫大的影响。该参数指定使用的深度神经网络使用几层hidden layer以及每个layer有几个神经元。本例使用的[100, 75, 50]即3层,第一层hidden layer有100个神经元,第二层有75个,第三层有50个。因为该参数是python的list,所以可以任意指定。你可以尝试改变该参数以取得更高的准确率。我将在下一个例子里使用tensorboard详细说明训练过程,以及参数将对训练结果造成怎样的影响。

参考资料:
[7] THE MNIST DATABASE of handwritten digits http://yann.lecun.com/exdb/mnist/
[8] TensorFlow Python API r1.6 : https://www.tensorflow.org/api_guides/python/summary#Generation_of_Summaries

机器学习遇上时尚潮流 - 使用 MNIST 数据集训练模型
https://zhuanlan.zhihu.com/p/32710865

posted @ 2018-04-24 17:16  从流域到海域  阅读(67)  评论(0编辑  收藏  举报