CoordConv实现

import torch
import torch.nn as nn
'''
An alternative PyTorch implementation that infers the x-y dimensions automatically.
Paper: "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution"

https://zhuanlan.zhihu.com/p/443583240
https://blog.csdn.net/oYeZhou/article/details/116717210

'''
class AddCoords(nn.Module):
    """Append normalized coordinate channels to a 4-D feature map.

    This is the coordinate-augmentation step of CoordConv. Two channels holding
    x and y coordinates, linearly spaced over [-1, 1], are concatenated after
    the input channels. An optional radius channel follows them.

    Args:
        with_r: if True, also append the channel
            ``sqrt((x - 0.5)**2 + (y - 0.5)**2)``.
    """

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, ins_feat):
        """Concatenate coordinate channels onto ``ins_feat``.

        Args:
            ins_feat: tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels + 2 (+1 if with_r), height, width).
            Channel ``channels`` holds x, which varies along the width axis
            (-1 at the first column, 1 at the last). Channel ``channels + 1``
            holds y, which varies along the height axis (-1 at the first row,
            1 at the last). When ``ins_feat`` is floating point, the coordinate
            channels are created in its dtype.
        """
        batch, _, height, width = ins_feat.shape
        device = ins_feat.device
        # Build the coordinates in the input's dtype. Fixed float32 coordinates
        # would, depending on the PyTorch version, be rejected by torch.cat or
        # promote the result (losing float64 precision, or turning a half
        # input into float32 that no longer matches the conv weights).
        dtype = (ins_feat.dtype if ins_feat.is_floating_point()
                 else torch.get_default_dtype())
        x_range = torch.linspace(-1, 1, width, device=device, dtype=dtype)
        y_range = torch.linspace(-1, 1, height, device=device, dtype=dtype)
        # Broadcast the 1-D ranges to (batch, 1, height, width) with views.
        # This avoids torch.meshgrid, which warns when `indexing` is omitted on
        # newer PyTorch and does not accept that argument on older releases.
        x = x_range.view(1, 1, 1, width).expand(batch, 1, height, width)
        y = y_range.view(1, 1, height, 1).expand(batch, 1, height, width)
        feats = [ins_feat, x, y]
        if self.with_r:
            # NOTE(review): the coordinates span [-1, 1], so the grid centre
            # is (0, 0). This instead measures the distance from (0.5, 0.5).
            # The offset is kept unchanged so previously trained weights stay
            # valid; confirm before changing it.
            feats.append(torch.sqrt((x - 0.5) ** 2 + (y - 0.5) ** 2))
        # Channel order: input channels, x, y[, r].
        return torch.cat(feats, dim=1)


class CoordConv(nn.Module):
    """A ``nn.Conv2d`` whose input is first augmented by ``AddCoords``.

    ``AddCoords`` appends two coordinate channels, plus one radius channel when
    ``with_r`` is True. The convolution's input channel count is widened to
    match.

    Args:
        in_channels: channel count of the tensor passed to ``forward``, before
            the coordinate channels are appended.
        out_channels: number of output channels of the convolution.
        with_r: forwarded to ``AddCoords``; also appends the radius channel.
        **kwargs: passed unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # Two coordinate channels always, plus one when the radius is added.
        extra_channels = 3 if with_r else 2
        self.conv = nn.Conv2d(in_channels + extra_channels, out_channels, **kwargs)

    def forward(self, x):
        """Append coordinate channels to ``x`` and convolve the result."""
        augmented = self.addcoords(x)
        return self.conv(augmented)

 

posted @ 2023-02-07 22:31  dangxusheng  阅读(83)  评论(0编辑  收藏  举报