随笔 - 7  文章 - 0  评论 - 0  阅读 - 1072

神经网络特征图的可视化

神经网络的可视化过程

特征图可视化

""""
神经网络的可视化过程
"""
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
import numpy as np
import cv2
from torchsummary import summary
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
nn.MaxPool2d(2),
nn.MaxPool2d(2),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=1),
nn.MaxPool2d(2),
nn.ReLU()
)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
return x2
path = "2.png"
trans = transforms.Compose([transforms.ToTensor(),
transforms.Resize((224, 224)),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
x = Image.open(path).convert("RGB")
x = trans(x)
x = torch.unsqueeze(x, 0) # 填充一维
modol = Test()
y = modol(x) # y = [1, 8, 112, 112]([N, C, W, H])
#特征图的可视化
def map_feature(img, images_per_row):
tt = img.detach().numpy()
layer_names = []
for layer in modol._modules.items():
layer_names.append(layer[0])
for layer_name, layer_activation in zip(layer_names, tt):
n_features = layer_activation.shape[0] # 8
size = layer_activation.shape[1] # 112
n_cols = n_features // images_per_row # 2
display_grid = np.zeros((size * n_cols, images_per_row * size)) # [112*2, 112*4]
for col in range(n_cols):
for row in range(images_per_row):
channel_image = layer_activation[col * images_per_row + row, :, :]
channel_image -= channel_image.mean()
channel_image /= channel_image.std()
channel_image *= 64
channel_image += 128
channel_image = np.clip(channel_image, 0, 255).astype('uint8')
display_grid[col * size: (col + 1) * size, row * size: (row + 1) * size] = channel_image
scale = 1. / size
plt.figure(figsize=(scale * display_grid.shape[1],
scale * display_grid.shape[0]))
plt.title(layer_name)
plt.grid(False)
plt.imshow(display_grid, aspect='auto', cmap='viridis')
plt.savefig(layer_name + ".png")
plt.show()
map_feature(img = y, images_per_row = 8)

代码无注释,哪句有问题,欢迎留言,顺便给个关注。

posted on   钱了个钱  阅读(174)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· 如何使用 Uni-app 实现视频聊天(源码,支持安卓、iOS)
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
< 2025年2月 >
26 27 28 29 30 31 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 1
2 3 4 5 6 7 8

点击右上角即可分享
微信分享提示