Thesis-CCNet: Criss-Cross Attention for Semantic Segmentation
Thesis-CCNet: Criss-Cross Attention for Semantic Segmentation
CCNet: Criss-Cross Attention for Semantic Segmentation
- 获得特征图X之后,先应用卷积得到降维后的特征图H,并将其送入十字交叉注意力模块(CCA),得到新的特征图H'。H'只聚合了水平和竖直方向上的上下文信息,还不足以完成语义分割。为了获得更丰富、更密集的上下文信息,将H'再次送入注意力模块,得到特征图H''。此时H''的每个位置都聚合了所有像素的信息。这种递归结构称为递归十字交叉注意力模块(RCCA)。
-
局部特征图H(C \(\times\) W \(\times\) H)经1 \(\times\) 1卷积降维,得到特征图Q和K(C' \(\times\) W \(\times\) H),其中C'<C。对于Q空间维度上的任一位置u(即一个像素),可以取出向量\(Q_u\),它是一个C'维的通道向量。同时,从K中取出与u位于同一行或同一列的共(H+W-1)个位置上的特征向量,组成集合\(\Omega_u\)(形状为(H+W-1) \(\times\) C'),其形状如下图所示(u位于十字的中心):
-
ccnet_model.py
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
from torch.nn import Softmax
# Download locations of the pretrained backbone checkpoints, keyed by
# architecture name.
model_urls = dict(
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
)

# Module-level alias for the 2-D batch-normalisation layer class.
BatchNorm2d = nn.BatchNorm2d
def INF(B, H, W, device=None):
    """Build the additive mask that hides each pixel from itself in its column.

    The result has shape ``(B * W, H, H)``; every ``H x H`` slice carries
    ``-inf`` on its diagonal and zero elsewhere. Adding it to the column-wise
    attention energies removes the centre pixel from the column branch, so the
    centre is only counted once (through the row branch) in the softmax.

    Args:
        B: batch size.
        H: feature-map height.
        W: feature-map width.
        device: device on which to allocate the mask. ``None`` keeps the
            legacy behaviour of allocating on CUDA device 1.

    Returns:
        A float tensor of shape ``(B * W, H, H)`` on ``device``.
    """
    if device is None:
        # Legacy default kept for callers that still use the 3-argument form.
        device = torch.device("cuda", 1)
    diagonal = torch.tensor(float("inf"), device=device).repeat(H)
    return -torch.diag(diagonal, 0).unsqueeze(0).repeat(B * W, 1, 1)


class CC_module(nn.Module):
    """Criss-cross attention over the row and column of every pixel.

    For each spatial position the module attends to the ``H + W - 1``
    positions that share its row or column, aggregates the projected values
    with those attention weights, and adds the result to the input scaled by
    the learnable scalar ``gamma`` (initialised to zero).

    Args:
        in_dim: number of input (and output) channels. Queries and keys are
            projected to ``in_dim // 8`` channels.
    """

    def __init__(self, in_dim):
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Softmax runs over the concatenated (column + row) attention axis.
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply one criss-cross attention pass.

        Args:
            x: tensor of shape ``(batch, in_dim, height, width)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        m_batchsize, _, height, width = x.size()

        # Column view folds the width axis into the batch: (B*W, C, H).
        # Row view folds the height axis into the batch:   (B*H, C, W).
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(
            m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(
            m_batchsize * height, -1, width).permute(0, 2, 1)

        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        # The self-mask is allocated on the input's device so the module works
        # on CPU and on any GPU index, not only on a fixed CUDA device.
        self_mask = self.INF(m_batchsize, height, width, device=x.device)
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self_mask).view(
            m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)

        # Joint softmax over the H column entries followed by the W row entries.
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(
            m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(
            m_batchsize * height, width, width)

        # Aggregate values along columns and rows, then restore (B, C, H, W).
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(
            m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(
            m_batchsize, height, -1, width).permute(0, 2, 1, 3)

        return self.gamma * (out_H + out_W) + x
class RCCAModule(nn.Module):
    """Recurrent criss-cross attention head producing per-class logits.

    The input is reduced to ``in_channels // 4`` channels, passed
    ``recurrence`` times through one shared ``CC_module``, refined by a 3x3
    convolution, concatenated with the original input and finally classified.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: width of the hidden 3x3 convolution in the classifier.
        num_classes: number of output channels (one per class).
    """

    def __init__(self, in_channels, out_channels, num_classes):
        super(RCCAModule, self).__init__()
        inter_channels = in_channels // 4
        self.conva = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            BatchNorm2d(inter_channels),
            nn.ReLU(inplace=False),
        )
        self.cca = CC_module(inter_channels)
        self.convb = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            BatchNorm2d(inter_channels),
            nn.ReLU(inplace=False),
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + inter_channels, out_channels,
                      kernel_size=3, padding=1, dilation=1, bias=False),
            BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
            nn.Dropout2d(0.1),
            # The classifier consumes the out_channels produced just above;
            # a fixed 512 here only worked when out_channels happened to be 512.
            nn.Conv2d(out_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def forward(self, x, recurrence=2):
        """Return class logits with the same spatial size as ``x``.

        Args:
            x: feature map of shape ``(batch, in_channels, H, W)``.
            recurrence: how many times the shared attention module is applied.
        """
        output = self.conva(x)
        for _ in range(recurrence):
            output = self.cca(output)
        output = self.convb(output)
        # Fuse the attended features with the untouched input features.
        output = self.bottleneck(torch.cat([x, output], 1))
        return output
# Custom Conv2d layer built on F.conv2d: F.conv2d is only the convolution
# operation, whereas nn.Conv2d is the layer class that owns the parameters.
class Conv2d(nn.Conv2d):
    """2-D convolution whose kernel is standardised on every forward pass.

    Each output filter is shifted to zero mean and divided by its standard
    deviation (plus 1e-5) before the convolution is evaluated. The stored
    ``weight`` parameter itself is left untouched.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the per-filter standardised kernel."""
        kernel = self.weight

        # Per-filter mean, reduced one axis at a time (in-channels, then
        # kernel height, then kernel width).
        filter_mean = kernel
        for axis in (1, 2, 3):
            filter_mean = filter_mean.mean(dim=axis, keepdim=True)
        centred = kernel - filter_mean

        # Per-filter standard deviation; the epsilon guards against division
        # by zero.
        flat = centred.view(centred.size(0), -1)
        filter_std = flat.std(dim=1).view(-1, 1, 1, 1) + 1e-5
        standardised = centred / filter_std.expand_as(centred)

        return F.conv2d(
            x,
            standardised,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
# ResNet bottleneck block: a 1x1 -> 3x3 -> 1x1 convolution stack that first
# reduces and then restores the channel count to lower the computational cost.
class Bottleneck(nn.Module):
    """Residual bottleneck block with pluggable convolution and norm layers.

    Args:
        inplanes: channels of the block input.
        planes: reduced channel count used inside the block; the block output
            has ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input so that the shortcut
            matches the shape of the main branch.
        dilation: dilation (and padding) of the 3x3 convolution.
        conv: convolution layer class; defaults to ``nn.Conv2d`` when ``None``.
        norm: normalisation layer class; defaults to ``nn.BatchNorm2d`` when
            ``None``.
    """

    expansion = 4  # Factor by which the last 1x1 conv widens the channels.

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None):
        super(Bottleneck, self).__init__()
        # The ``None`` defaults used to be called directly, which raised
        # TypeError; fall back to the standard layers instead.
        if conv is None:
            conv = nn.Conv2d
        if norm is None:
            norm = nn.BatchNorm2d
        self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm(planes)
        self.conv2 = conv(planes, planes, kernel_size=3, stride=stride,
                          dilation=dilation, padding=dilation, bias=False)
        self.bn2 = norm(planes)
        self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = norm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The downsample module reshapes the shortcut so that both branches
        # can be added element-wise.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
# ASPP (atrous spatial pyramid pooling) module of DeepLabv3.
class ASPP(nn.Module):
    """Five parallel branches over the same input, fused and classified.

    The branches are a 1x1 convolution, three dilated 3x3 convolutions
    (rates ``6 * mult``, ``12 * mult``, ``18 * mult``) and a global-average-
    pooling branch. Their outputs are concatenated, reduced by a 1x1
    convolution and mapped to ``num_classes`` channels.

    Args:
        C: number of input channels.
        depth: number of filters in every branch.
        num_classes: number of output channels.
        conv: convolution layer class used for the branches.
        norm: normalisation layer factory, called as
            ``norm(depth, momentum=momentum)``.
        momentum: momentum handed to every normalisation layer.
        mult: multiplier applied to the dilation rates.
    """

    def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
        super(ASPP, self).__init__()
        self._C = C  # Channels entering the module.
        self._depth = depth  # Filters per branch.
        self._num_classes = num_classes

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU(inplace=True)
        # Branch 1: plain 1x1 convolution.
        self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # Branches 2-4: dilated 3x3 convolutions with rates 6, 12 and 18
        # (scaled by ``mult``); padding equals dilation to keep the size.
        self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(6 * mult), padding=int(6 * mult),
                          bias=False)
        self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(12 * mult), padding=int(12 * mult),
                          bias=False)
        self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(18 * mult), padding=int(18 * mult),
                          bias=False)
        # Branch 5: 1x1 convolution applied to the globally pooled features.
        self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # ``momentum`` must be passed by keyword: the second positional
        # parameter of nn.BatchNorm2d is ``eps``, so a positional call silently
        # set eps=momentum and left the real momentum at its default.
        self.aspp1_bn = norm(depth, momentum=momentum)
        self.aspp2_bn = norm(depth, momentum=momentum)
        self.aspp3_bn = norm(depth, momentum=momentum)
        self.aspp4_bn = norm(depth, momentum=momentum)
        self.aspp5_bn = norm(depth, momentum=momentum)
        # Fusion of the five concatenated branches.
        self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1,
                          bias=False)
        self.bn2 = norm(depth, momentum=momentum)
        # Per-class scoring layer.
        self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        """Return class scores with the same spatial size as ``x``."""
        x1 = self.relu(self.aspp1_bn(self.aspp1(x)))
        x2 = self.relu(self.aspp2_bn(self.aspp2(x)))
        x3 = self.relu(self.aspp3_bn(self.aspp3(x)))
        x4 = self.relu(self.aspp4_bn(self.aspp4(x)))

        x5 = self.global_pooling(x)
        x5 = self.relu(self.aspp5_bn(self.aspp5(x5)))
        # Bilinearly upsample the pooled branch back to the input's spatial
        # size so it can be concatenated with the other branches.
        x5 = F.interpolate(x5, size=(x.shape[2], x.shape[3]), mode='bilinear',
                           align_corners=True)

        # After concatenation the channel count is five times ``depth``.
        x = torch.cat((x1, x2, x3, x4, x5), 1)
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.conv3(x)
        return x
# DeepLabv3-style segmentation network on a ResNet backbone, extended with a
# recurrent criss-cross attention (CCNet) head.
class ResNet(nn.Module):
    """ResNet backbone with ASPP, CCNet and auxiliary heads fused by a 1x1 conv.

    Args:
        block: residual block class (must expose ``expansion``).
        block_num: number of blocks in each of the four residual stages.
        num_classes: number of segmentation classes.
        num_groups: accepted for interface compatibility; not used here.
        weight_std: use the weight-standardised ``Conv2d`` instead of
            ``nn.Conv2d`` for the backbone and ASPP convolutions.
        beta: replace the 7x7 stem convolution with three stacked 3x3 ones.
        pretrained: load pretrained backbone weights after initialisation.
    """

    def __init__(self, block, block_num, num_classes, num_groups=None, weight_std=False, beta=False, pretrained=False):
        super(ResNet, self).__init__()
        # Input channels of the next residual stage (updated by _make_layer).
        self.inplanes = 64
        self.norm = nn.BatchNorm2d
        self.conv = Conv2d if weight_std else nn.Conv2d

        if not beta:
            # Standard ResNet stem: a single 7x7 convolution.
            self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        else:
            # Alternative stem: three stacked 3x3 convolutions.
            self.conv1 = nn.Sequential(
                self.conv(3, 64, 3, stride=2, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False))
        self.bn1 = self.norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        self.layer1 = self._make_layer(block, 64, block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        # The last stage keeps the resolution and uses dilated convolutions.
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, dilation=2)

        # CCNet head on the layer4 features.
        self.ccnet = RCCAModule(512 * block.expansion, 512, num_classes)
        # Auxiliary head on the layer3 features, whose width is
        # 256 * block.expansion (1024 for the bottleneck block).
        self.dsn = nn.Sequential(
            nn.Conv2d(256 * block.expansion, 512, kernel_size=3, stride=1, padding=1),
            self.norm(512), nn.ReLU(inplace=False),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )
        # ASPP head; 512 * block.expansion is the layer4 output width.
        self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm)
        # 1x1 fusion of the three concatenated heads (ASPP, CCNet, auxiliary).
        self.conv2 = self.conv(num_classes * 3, num_classes, kernel_size=1, stride=1, bias=False)

        # Weight initialisation. Only instances of ``self.conv`` are matched,
        # so with weight_std=True plain nn.Conv2d layers keep their default
        # initialisation.
        for m in self.modules():
            if isinstance(m, self.conv):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block may change the resolution or channel count, so it
        alone receives the stride and the shortcut projection.
        """
        # Integer division keeps the dilation an int; true division produced a
        # float for any dilation above 2.
        first_dilation = max(1, dilation // 2)

        downsample = None
        if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                self.conv(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, dilation=first_dilation, bias=False),
                self.norm(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample,
                        dilation=first_dilation, conv=self.conv, norm=self.norm)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class logits resized to the input's spatial size.

        Args:
            x: image batch of shape ``(batch, 3, H, W)``.
        """
        size = (x.shape[2], x.shape[3])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Auxiliary prediction from the layer3 features.
        x_dsn = self.dsn(x)
        x_res = self.layer4(x)
        # ASPP prediction.
        x_aspp = self.aspp(x_res)
        # CCNet prediction with two attention passes.
        x_ccnet_1 = self.ccnet(x_res, 2)
        x_ccnet_2 = torch.cat([x_ccnet_1, x_dsn], 1)
        # All three heads emit num_classes channels -> 3 * num_classes here.
        out = torch.cat((x_aspp, x_ccnet_2), 1)
        out = self.conv2(out)
        # Bilinear resize back to the input resolution.
        out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
        return out

    def _load_pretrained_model(self, arch='resnet50'):
        """Copy matching pretrained backbone weights into this model.

        Args:
            arch: key into ``model_urls``. The previous hard-coded key
                ``'resnet152'`` is not present in ``model_urls`` and raised
                KeyError.
        """
        pretrain_dict = model_zoo.load_url(model_urls[arch])
        state_dict = self.state_dict()
        # Keep only entries that exist in this model with an identical shape;
        # everything else (e.g. the classifier of the checkpoint) is skipped.
        compatible = {
            k: v for k, v in pretrain_dict.items()
            if k in state_dict and v.shape == state_dict[k].shape
        }
        state_dict.update(compatible)
        self.load_state_dict(state_dict)
# Model factory.
def resnet50(pretrained=False, **kwargs):
    """Construct the segmentation network on a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, initialise the backbone from the
            ImageNet-pretrained ResNet-50 checkpoint.
        **kwargs: forwarded to ``ResNet``; ``num_classes`` defaults to 4.
    """
    # Default class count, overridable by the caller without triggering a
    # duplicate-keyword TypeError.
    kwargs.setdefault('num_classes', 4)
    # [3, 4, 6, 3] is the number of residual blocks per stage.
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet50'])
        own_state = model.state_dict()
        # The checkpoint holds classifier weights this model lacks, and the
        # model holds segmentation heads the checkpoint lacks, so a strict
        # load always fails. Load only the shape-compatible backbone entries.
        compatible = {
            k: v for k, v in checkpoint.items()
            if k in own_state and v.shape == own_state[k].shape
        }
        model.load_state_dict(compatible, strict=False)
    return model
if __name__ == "__main__":
    # Smoke test: push a random batch through the network and print shapes.
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    # resnet50 is the only factory defined in this file; the previous call to
    # resnet152() raised NameError.
    model = resnet50()
    model = model.to(device)
    x = torch.rand((4, 3, 320, 320))
    x = x.to(device)
    print(x.shape)
    print('====================')
    output = model(x)
    print('====================')
    print(output.shape)