CNN训练可视化特征图(tensorflow2.x实现)
CNN训练可视化(tensorflow2.x实现)
原理介绍
卷积层由多个卷积核组成,可以将每个卷积核视为一种特征提取方式。当一个卷积核处理图像数据后,卷积核会提取图中的特征,形成新的特征图。
实例化VGG16
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16
model = VGG16(weights='imagenet')
model.summary()
加载图片并进行预处理
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
from numpy as np
img_path = '0a0c3edd8a0bb0f5b4d6466c9a1af3c6-0.jpg'
#缩放
img = image.load_img(img_path,target_size=(224,224))
# #显示图片
# plt.imshow(img)
# plt.show()
x = image.img_to_array(img)
#维度扩展
x = np.expand_dims(x,axis=0)
#图片预处理
x = preprocess_input(x)
加载图片
预测图片
preds = model.predict(x)
print('predicted', decode_predictions(preds, top=3)[0])
获取指定层对应的输出
layer_names = ['block1_conv1','block3_conv1','block5_conv1']
#获取指定层的输出
layer_outputs = [model.get_layer(layer_name).output for layer_name in layer_names]
#获得模型指定层的输出
activation_model = keras.models.Model(inputs=model.input,outputs=layer_outputs)
#获得输出
activations = activation_model.predict(x)
first_layer_activation = activations[0]
plt.matshow(first_layer_activation[0,:,:,1],cmap='viridis')
plt.show()
可视化CNN训练过程
import numpy as np
images_per_row = 8
for layer_name,layer_activation in zip(layer_names,activations):
#获取卷积核的个数
n_features = layer_activation.shape[-1]
#特征图的形状(1,size,size,n_features)
size = layer_activation.shape[1]
#n_cols = n_features // images_per_row
n_cols = 8
display_grid = np.zeros((size * n_cols,images_per_row * size))
for col in range(n_cols):
for row in range(images_per_row):
channel_image = layer_activation[0,:,:,col * images_per_row + row]
#归一化
channel_image -= channel_image.mean()
channel_image /= channel_image.std()
channel_image *= 128
channel_image += 128
channel_image = np.clip(channel_image,0,255).astype(np.uint8)
display_grid[col * size : (col + 1) * size,row * size: (row + 1) * size] = channel_image
scale = 1. / size
plt.figure(figsize=(scale * display_grid.shape[1],scale * display_grid.shape[0]))
plt.title(layer_name)
plt.grid(False)
plt.imshow(display_grid,aspect='auto',cmap='viridis')
plt.show()