CNN 每层卷积结果的可视化展示(以 3D-IRCADb 肝脏数据为例)

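试着展示了肝脏 CT 每层卷积之后的结果。下面的代码把同一个 3×3×3 卷积块(Conv3d → ReLU → BatchNorm3d,权重为随机初始化、未经训练)连续作用 num 次,并将第 num 次的输出保存为 NIfTI 文件,便于用医学影像软件查看。代码如下: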

import torch
import torch.nn as nn
import SimpleITK as sitk
import numpy as np

def change_indenty(ct, lower=40, upper=400):
    """Clamp the values of ``ct`` to the window ``[lower, upper]`` in place.

    Values below ``lower`` are set to ``lower`` and values above ``upper``
    are set to ``upper``; everything in between is left untouched.

    Args:
        ct: Array-like object supporting boolean-mask assignment (the script
            passes the NumPy array returned by SimpleITK). It is MODIFIED IN
            PLACE. Presumably CT intensities in Hounsfield units -- confirm
            against the dataset.
        lower: Lower bound of the window. Defaults to 40, the value that was
            previously hard-coded.
        upper: Upper bound of the window. Defaults to 400, the value that was
            previously hard-coded.

    Returns:
        The same object that was passed in, after clamping.

    Raises:
        ValueError: If ``lower`` is greater than ``upper``.
    """
    if lower > upper:
        raise ValueError(
            "lower bound %r must not exceed upper bound %r" % (lower, upper)
        )
    # Boolean-mask assignment keeps the operation in place, so callers that
    # hold a reference to the input array see the clamped data as before.
    ct[ct < lower] = lower
    ct[ct > upper] = upper
    return ct

class Dialte(nn.Module):
    """Apply one shared 3D convolution block to a volume several times.

    The block is ``Conv3d(1 -> 1, kernel 3, stride 1, padding 1)`` followed by
    ``ReLU`` and ``BatchNorm3d(1)``. The very same block (same weights) is
    re-applied on every iteration, so ``num`` controls how many times the
    input is passed through it, not how many distinct layers exist.

    Args:
        num: Number of times ``forward`` runs the block on its input.
    """

    def __init__(self, num):
        super().__init__()
        self.num = num
        self.act = nn.ReLU(inplace=False)
        # Stored as a class (not an instance) and instantiated below.
        self.norm = nn.BatchNorm3d
        block_layers = [
            nn.Conv3d(1, 1, kernel_size=3, stride=1, padding=1),
            self.act,
            self.norm(1),
        ]
        self.conv1 = nn.Sequential(*block_layers)

    def forward(self, x):
        """Return ``x`` after ``self.num`` passes through ``self.conv1``.

        With kernel 3, stride 1 and padding 1 the spatial size is unchanged,
        so the output has the same shape as the input.
        """
        out = x
        for _ in range(self.num):
            out = self.conv1(out)
        return out

if __name__ == "__main__":
    # Script entry point: read a NIfTI volume, clamp its intensities, run it
    # through the shared convolution block five times and save the result
    # next to the input with the input's spacing, origin and direction.
    import os  # local import: only this entry point needs it

    path = r"D:\myProject\HDC_vessel_seg\datasets\nii\image_2.nii"
    image = sitk.ReadImage(path)
    img_num = sitk.GetArrayFromImage(image)
    img_num = change_indenty(img_num)
    # Conv3d expects a 5-D input, so prepend batch and channel axes of size 1.
    img_num = img_num[np.newaxis, np.newaxis, ...].astype(np.float32)
    img_num = torch.from_numpy(img_num)
    print(img_num.shape)

    model = Dialte(num=5)
    # NOTE: the model is left in its default training mode, as before, so
    # BatchNorm3d normalises with the statistics of this input volume;
    # calling model.eval() here would change the saved image.
    # Gradients are never used, so disable autograd to avoid keeping the
    # activations of every pass over the full volume in memory.
    with torch.no_grad():
        x = model(img_num)
    # To inspect the unprocessed (clamped) volume instead, use: x = img_num
    x = x[0, 0, ...]  # drop the batch and channel axes again
    x = x.detach().cpu().numpy()
    print("x",x.shape)

    predict_seg = sitk.GetImageFromArray(x)
    # Copy the geometry so the output overlays the original scan in a viewer.
    predict_seg.SetSpacing(image.GetSpacing())
    predict_seg.SetOrigin(image.GetOrigin())
    predict_seg.SetDirection(image.GetDirection())

    # Rename only the file name, never a directory that happens to contain
    # "image", and never fall back to the input path (which would overwrite
    # the source scan).
    directory, file_name = os.path.split(path)
    out_name = file_name.replace("image", "pre_image")
    if out_name == file_name:
        out_name = "pre_" + file_name
    sitk.WriteImage(predict_seg, os.path.join(directory, out_name))

  

  结果:

原始

 

第一层:

 

第二层:

 

第三层:

 

第四层:

 

第五层:

 

posted @ 2022-05-16 17:13  九叶草  阅读(375)  评论(0)  编辑  收藏  举报