Thesis-CCNet: Criss-Cross Attention for Semantic Segmentation

Thesis-CCNet: Criss-Cross Attention for Semantic Segmentation

CCNet: Criss-Cross Attention for Semantic Segmentation

  • 获得骨干网络输出的特征图X之后,先用卷积对其降维得到特征图H,再将H送入十字交叉注意力模块(CCA)得到新的特征图H'。H'中的每个位置只聚合了其所在行和列(水平与竖直方向)上的上下文信息,还不足以进行语义分割。为了获得更丰富、更稠密的上下文信息,将H'再次送入同一个注意力模块,得到特征图H''。此时H''中的每个位置都(间接地)聚合了全图所有像素的信息。这种循环使用十字交叉注意力的结构称为递归十字交叉注意力模块(RCCA)。

  • 对局部特征图H(形状为C × W × H)分别使用1 × 1卷积降维,得到特征图Q和K(形状均为C' × W × H),其中C' < C。对于Q中的任意一个空间位置u(即一个像素),取出该位置上的特征向量Q_u,它是一个长度为C'的向量。同时,从K中取出与u位于同一行和同一列的所有位置的特征向量,共(H + W - 1)个,组成集合Ω_u。Ω_u在空间上呈以u为中心的十字形,如下图所示(u位于十字的中心):

  • ccnet_model.py

copy
import math

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import Softmax
from torch.nn import functional as F

# ImageNet checkpoints that can be copied into the ResNet backbone.
model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}

BatchNorm2d = nn.BatchNorm2d


def INF(B, H, W, device=None):
    """Return the mask that removes each pixel's self-affinity from the column branch.

    The result has shape ``(B * W, H, H)`` with ``-inf`` on the diagonal and
    ``0`` elsewhere. ``CC_module`` adds it to the column (height) energies, so
    after the joint row + column softmax the centre pixel is counted only once
    (through the row branch) and every position attends to ``H + W - 1``
    positions.

    Args:
        B: batch size.
        H: feature-map height.
        W: feature-map width.
        device: device to allocate the mask on; ``None`` means torch's default
            device. (The previous version hard-coded ``cuda:1`` and crashed on
            every other device.)
    """
    diagonal = torch.full((H,), float('-inf'), device=device)
    return torch.diag(diagonal, 0).unsqueeze(0).repeat(B * W, 1, 1)


class CC_module(nn.Module):
    """Criss-cross attention over a ``(B, C, H, W)`` feature map.

    Each position aggregates value features from every position in its own
    row and column, weighted by a single softmax over those ``H + W - 1``
    affinities. The result is scaled by a learnable ``gamma`` (initialised to
    0, so the module starts as the identity) and added to the input.

    Args:
        in_dim: number of input channels; query/key use ``in_dim // 8``.
    """

    def __init__(self, in_dim):
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, _, height, width = x.size()

        # Column ("_H", batched over B*W) and row ("_W", batched over B*H) views.
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        # The mask is built on x's device so the module runs on CPU or any GPU.
        energy_H = (torch.bmm(proj_query_H, proj_key_H)
                    + self.INF(m_batchsize, height, width, device=x.device)
                    ).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        # One softmax over the concatenated column + row affinities.
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + x


class RCCAModule(nn.Module):
    """Recurrent criss-cross attention head that produces class logits.

    Pipeline: 3x3 conv down to ``in_channels // 4`` channels, then the same
    ``CC_module`` applied ``recurrence`` times, then a 3x3 conv. The result
    is concatenated with the input and passed through a
    conv-BN-ReLU-dropout-1x1 classifier.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the bottleneck before the classifier.
        num_classes: channels of the output logits.
    """

    def __init__(self, in_channels, out_channels, num_classes):
        super(RCCAModule, self).__init__()
        inter_channels = in_channels // 4
        self.conva = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            BatchNorm2d(inter_channels), nn.ReLU(inplace=False))
        self.cca = CC_module(inter_channels)
        self.convb = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            BatchNorm2d(inter_channels), nn.ReLU(inplace=False))
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),
            BatchNorm2d(out_channels), nn.ReLU(inplace=False),
            nn.Dropout2d(0.1),
            # Must match out_channels (was hard-coded to 512).
            nn.Conv2d(out_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True))

    def forward(self, x, recurrence=2):
        output = self.conva(x)
        # The same CC_module (shared weights) is applied `recurrence` times.
        for _ in range(recurrence):
            output = self.cca(output)
        output = self.convb(output)
        output = self.bottleneck(torch.cat([x, output], 1))
        return output


class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardized on every forward pass.

    Before convolving, each output filter is shifted to zero mean and divided
    by its standard deviation (+1e-5), i.e. weight standardization. The
    convolution itself is performed with ``F.conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3 (optionally dilated), 1x1 expand.

    Args:
        inplanes: input channels.
        planes: reduced channels inside the block; output is
            ``planes * expansion``.
        stride: stride of the 3x3 conv.
        downsample: optional module applied to the input so the residual
            matches the output shape.
        dilation: dilation (and padding) of the 3x3 conv.
        conv: conv layer class; defaults to ``nn.Conv2d``.
        norm: normalization layer class; defaults to ``nn.BatchNorm2d``.
    """

    expansion = 4  # Output channels = planes * expansion.

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None):
        super(Bottleneck, self).__init__()
        if conv is None:
            conv = nn.Conv2d
        if norm is None:
            norm = nn.BatchNorm2d
        self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm(planes)
        self.conv2 = conv(planes, planes, kernel_size=3, stride=stride,
                          dilation=dilation, padding=dilation, bias=False)
        self.bn2 = norm(planes)
        self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = norm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Match the residual's channels/stride to the block output.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ASPP(nn.Module):
    """DeepLabv3 atrous spatial pyramid pooling head that produces class logits.

    It has five parallel branches: a 1x1 conv, three 3x3 convs with dilation
    ``6/12/18 * mult``, and an image-level pooling + 1x1 conv branch that is
    upsampled back. The branches are concatenated, fused by a 1x1 conv and
    classified by a final 1x1 conv.

    Args:
        C: input channels.
        depth: channels of each branch.
        num_classes: output logit channels.
        conv: conv layer class for the branches.
        norm: normalization layer class, called as ``norm(depth, momentum=...)``.
        momentum: running-statistics momentum passed to ``norm``.
        mult: multiplier applied to the atrous rates.
    """

    def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
        super(ASPP, self).__init__()
        self._C = C
        self._depth = depth
        self._num_classes = num_classes

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(6 * mult), padding=int(6 * mult), bias=False)
        self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(12 * mult), padding=int(12 * mult), bias=False)
        self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(18 * mult), padding=int(18 * mult), bias=False)
        self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # momentum is passed by keyword: the second positional parameter of
        # nn.BatchNorm2d is eps, so a positional call would silently set
        # eps=momentum and leave momentum at its default.
        self.aspp1_bn = norm(depth, momentum=momentum)
        self.aspp2_bn = norm(depth, momentum=momentum)
        self.aspp3_bn = norm(depth, momentum=momentum)
        self.aspp4_bn = norm(depth, momentum=momentum)
        self.aspp5_bn = norm(depth, momentum=momentum)
        # Fuses the five concatenated branches.
        self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1, bias=False)
        self.bn2 = norm(depth, momentum=momentum)
        # Per-pixel classifier.
        self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        x1 = self.relu(self.aspp1_bn(self.aspp1(x)))
        x2 = self.relu(self.aspp2_bn(self.aspp2(x)))
        x3 = self.relu(self.aspp3_bn(self.aspp3(x)))
        x4 = self.relu(self.aspp4_bn(self.aspp4(x)))

        x5 = self.global_pooling(x)
        x5 = self.relu(self.aspp5_bn(self.aspp5(x5)))
        # Broadcast the 1x1 image-level feature back to the input resolution.
        x5 = F.interpolate(x5, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), 1)  # 5 * depth channels.
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        return x


class ResNet(nn.Module):
    """ResNet backbone (output stride 16) fused with ASPP and RCCA heads.

    Stem conv and maxpool, then layer1 to layer3 (layer2/layer3 stride 2)
    and layer4 (stride 1, dilation 2). Three logit maps are produced:
    ``dsn`` on the layer3 features, plus ``ccnet`` (RCCA) and ``aspp`` on the
    layer4 features. They are concatenated, fused by a 1x1 conv, and
    bilinearly upsampled to the input size.

    Args:
        block: residual block class (must define ``expansion``).
        block_num: number of blocks in layer1..layer4.
        num_classes: number of output channels.
        num_groups: accepted for interface compatibility; not used (the
            model always uses ``nn.BatchNorm2d``).
        weight_std: if True, backbone/ASPP convs use the weight-standardized
            ``Conv2d``.
        beta: if True, the 7x7 stem conv is replaced by three 3x3 convs.
        pretrained: if True, copy matching ImageNet ResNet-50 weights.
    """

    def __init__(self, block, block_num, num_classes, num_groups=None, weight_std=False, beta=False,
                 pretrained=False):
        super(ResNet, self).__init__()
        self.inplanes = 64  # Input channels of the next residual block.
        self.norm = nn.BatchNorm2d
        self.conv = Conv2d if weight_std else nn.Conv2d

        if not beta:
            self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        else:
            self.conv1 = nn.Sequential(
                self.conv(3, 64, 3, stride=2, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False))
        self.bn1 = self.norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        # layer4 keeps the resolution and uses dilated convolutions instead.
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, dilation=2)

        # Recurrent criss-cross attention head on layer4 features.
        self.ccnet = RCCAModule(512 * block.expansion, 512, num_classes)
        # Auxiliary head on layer3 features (256 * expansion channels).
        self.dsn = nn.Sequential(
            nn.Conv2d(256 * block.expansion, 512, kernel_size=3, stride=1, padding=1),
            self.norm(512), nn.ReLU(inplace=False),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))
        self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm)
        # Fuses the three concatenated logit maps (aspp, ccnet, dsn).
        self.conv2 = self.conv(num_classes * 3, num_classes, kernel_size=1, stride=1, bias=False)

        for m in self.modules():
            if isinstance(m, self.conv):
                # He-style normal init with fan-out n = k_h * k_w * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Stack ``blocks`` residual blocks into one stage.

        The first block gets the stride, a downsample projection when the
        shape changes, and dilation ``max(1, dilation // 2)``. The remaining
        blocks use the full ``dilation``. Integer division keeps
        dilation/padding ints (``dilation / 2`` produced floats that
        ``F.conv2d`` rejects).
        """
        first_dilation = max(1, dilation // 2)
        downsample = None
        if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                self.conv(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride,
                          dilation=first_dilation, bias=False),
                self.norm(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, dilation=first_dilation,
                        conv=self.conv, norm=self.norm)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm))

        return nn.Sequential(*layers)

    def forward(self, x):
        size = (x.shape[2], x.shape[3])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # e.g. input (4, 3, 320, 320), num_classes=4 -> x_dsn (4, 4, 20, 20)
        x_dsn = self.dsn(x)
        x_res = self.layer4(x)  # e.g. (4, 2048, 20, 20)

        x_aspp = self.aspp(x_res)
        x_ccnet_1 = self.ccnet(x_res, 2)
        x_ccnet_2 = torch.cat([x_ccnet_1, x_dsn], 1)
        out = torch.cat((x_aspp, x_ccnet_2), 1)  # 3 * num_classes channels.
        out = self.conv2(out)
        out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
        return out

    def _load_pretrained_model(self, arch='resnet50'):
        """Copy ImageNet weights whose name and shape match this model (non-strict).

        Args:
            arch: key into ``model_urls``. The previous default pointed at
                ``'resnet152'``, which is not in ``model_urls``.
        """
        pretrain_dict = model_zoo.load_url(model_urls[arch])
        state_dict = self.state_dict()
        model_dict = {k: v for k, v in pretrain_dict.items()
                      if k in state_dict and state_dict[k].shape == v.shape}
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def resnet50(pretrained=False, **kwargs):
    """Construct the ResNet-50 ([3, 4, 6, 3] bottleneck blocks) CCNet/ASPP model.

    Args:
        pretrained: if True, copy matching ImageNet ResNet-50 backbone
            weights. The segmentation heads stay randomly initialised; a
            strict load of the classification checkpoint would fail.
        **kwargs: forwarded to ``ResNet``; ``num_classes`` defaults to 4.
    """
    kwargs.setdefault('num_classes', 4)
    return ResNet(Bottleneck, [3, 4, 6, 3], pretrained=pretrained, **kwargs)


if __name__ == "__main__":
    # Default CUDA device if available (the old 'cuda:1' fails on single-GPU hosts).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = resnet50()
    model = model.to(device)
    x = torch.rand((4, 3, 320, 320))
    x = x.to(device)
    print(x.shape)
    print('====================')
    output = model(x)
    print('====================')
    print(output.shape)
posted @   梁君牧  阅读(109)  评论(0编辑  收藏  举报
历史上的今天:
2021-08-13 【机器学习】机器学习算法整理
2021-08-13 【算法题】刷刷OJ题题题
点击右上角即可分享
微信分享提示
🚀