test1
class BasicBlock(nn.Module):
    """ResNet "basic" residual block (two 3x3 convolutions) with optional BatchNorm.

    Main path: conv3x3(stride) -> [BN] -> ReLU -> conv3x3 -> [BN].
    The main path is added to the shortcut, then a final ReLU is applied.
    The shortcut is the identity unless ``stride != 1`` or the channel count
    changes, in which case a strided 1x1 convolution projects the input.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels of each 3x3 convolution; the block
            outputs ``expansion * planes`` channels.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut.
        bn: If True, apply BatchNorm after each 3x3 convolution.
    """

    # Channel multiplier for the block output; 1 for the basic block.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=False):
        super().__init__()
        out_planes = self.expansion * planes
        # Submodules are registered in the same order as before
        # (conv1, bn1, conv2, bn2, shortcut) so parameter order and
        # state_dict keys stay compatible with existing checkpoints.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # Public boolean flag read by forward(); the BN layers are bn1/bn2.
        self.bn = bn
        if bn:
            self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        if bn:
            self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Projection shortcut to match the main path's shape.
            # NOTE(review): there is no BatchNorm here even when bn=True
            # (a BatchNorm2d at this position was left commented out);
            # confirm this is intended. Adding one would change state_dict keys.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            )
        else:
            # An empty nn.Sequential returns its input unchanged (identity).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Input tensor with ``in_planes`` channels; presumably shaped
                (N, in_planes, H, W) — TODO confirm against callers.

        Returns:
            Tensor with ``expansion * planes`` channels and spatial size
            ceil(H / stride) x ceil(W / stride), after the final ReLU.
        """
        if self.bn:
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
        else:
            out = F.relu(self.conv1(x))
            # No activation here: as in the BN path, the residual branch is
            # added to the shortcut *before* the ReLU. A ReLU at this point
            # would clamp the branch to be non-negative.
            out = self.conv2(out)
        out += self.shortcut(x)
        return F.relu(out)