# CoordConv implementation
"""CoordConv layers for PyTorch that infer the spatial size from the input.

An alternative implementation for PyTorch with auto-inferring the x-y
dimensions.

paper: An intriguing failing of convolutional neural networks and the
CoordConv solution
https://zhuanlan.zhihu.com/p/443583240
https://blog.csdn.net/oYeZhou/article/details/116717210
"""
import torch
import torch.nn as nn


class AddCoords(nn.Module):
    """Append normalized coordinate channels to a 4-D feature map.

    Two channels are appended after the input channels, in this order:
    the x coordinate (varies along the last axis) and the y coordinate
    (varies along the second-to-last axis), each spanning ``[-1, 1]``.
    With ``with_r=True`` a third, radius-like channel is appended as well.
    """

    def __init__(self, with_r=False):
        """
        Args:
            with_r: if true, also append the radius channel.
        """
        super().__init__()
        self.with_r = with_r

    def forward(self, ins_feat: torch.Tensor) -> torch.Tensor:
        """Concatenate coordinate channels onto ``ins_feat``.

        Args:
            ins_feat: tensor of shape ``(batch, channel, height, width)``.

        Returns:
            Tensor of shape ``(batch, channel + 2, height, width)``, or
            ``channel + 3`` when ``with_r`` is set. The coordinate channels
            live on the same device as ``ins_feat`` and, for floating-point
            inputs, share its dtype.

        Raises:
            ValueError: if ``ins_feat`` is not 4-dimensional.
        """
        # The unpacking doubles as a rank check: anything but a 4-D input
        # fails here with a ValueError.
        batch_size, _, height, width = ins_feat.shape

        # Match the input's floating dtype so the concatenation does not
        # silently promote e.g. float16 features to float32 (which would then
        # mismatch half-precision conv weights). Non-floating inputs keep the
        # default-dtype coordinates, as before.
        if ins_feat.is_floating_point():
            coord_dtype = ins_feat.dtype
        else:
            coord_dtype = torch.get_default_dtype()

        # Linearly spaced values from -1 to 1 along each spatial axis. They
        # are generated in the default dtype and cast afterwards.
        x_range = torch.linspace(-1, 1, width, device=ins_feat.device).to(coord_dtype)
        y_range = torch.linspace(-1, 1, height, device=ins_feat.device).to(coord_dtype)

        # Broadcast the 1-D ranges into (batch, 1, height, width) grids.
        # This is equivalent to torch.meshgrid(y_range, x_range) with "ij"
        # indexing followed by expand, but does not depend on meshgrid's
        # version-specific ``indexing`` argument.
        grid_shape = (batch_size, 1, height, width)
        x = x_range.view(1, 1, 1, width).expand(grid_shape)
        y = y_range.view(1, 1, height, 1).expand(grid_shape)

        # Input features followed by the positional features (x, then y);
        # the result is meant to feed the next convolution.
        channels = [ins_feat, x, y]

        if self.with_r:
            # NOTE: the 0.5 offset is applied to coordinates that already span
            # [-1, 1], so the radius is measured from (0.5, 0.5) rather than
            # from the centre of the map. Kept as-is so that weights trained
            # with this layer remain valid.
            rr = torch.sqrt(torch.pow(x - 0.5, 2) + torch.pow(y - 0.5, 2))
            channels.append(rr)

        return torch.cat(channels, dim=1)


class CoordConv(nn.Module):
    """``nn.Conv2d`` preceded by :class:`AddCoords`."""

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        """
        Args:
            in_channels: number of channels of the incoming feature map,
                not counting the coordinate channels added here.
            out_channels: number of channels produced by the convolution.
            with_r: if true, the radius channel is appended as well.
            **kwargs: forwarded unchanged to ``nn.Conv2d``.
        """
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # Two coordinate channels, plus one more for the optional radius.
        extra_channels = 3 if with_r else 2
        self.conv = nn.Conv2d(in_channels + extra_channels, out_channels, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Append coordinate channels to ``x`` and apply the convolution."""
        return self.conv(self.addcoords(x))