代码笔记26 PyTorch 复现 PointNet
1
浅浅记录一下 model 的复现,之后做好完整的工程放到 GitHub 上。
PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation
2
import torch.nn as nn
import torch
import numpy as np
class tnet(nn.Module):
    """T-Net: regress a ``k x k`` transform and apply it to the input.

    A shared point-wise MLP (kernel-size-1 convolutions, 64-128-1024) is
    max-pooled over the point axis into one 1024-d descriptor per sample.
    Three fully connected layers (512-256-k*k) turn that descriptor into a
    ``(k, k)`` matrix, which is offset by the identity and left-multiplied
    onto the input.

    Args:
        inplanes: number of input channels ``k``.
    """

    def __init__(self, inplanes: int):
        super(tnet, self).__init__()
        self.k = inplanes
        # conv layers in T-net
        self.relu = nn.ReLU(inplace=True)
        self.tconv1 = nn.Conv1d(in_channels=inplanes, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.tconv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm1d(128)
        self.tconv3 = nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm1d(1024)
        # fc layers in T-net
        self.tfc1 = nn.Linear(in_features=1024, out_features=512, bias=False)
        self.bnf1 = nn.BatchNorm1d(512)
        self.tfc2 = nn.Linear(in_features=512, out_features=256, bias=False)
        self.bnf2 = nn.BatchNorm1d(256)
        self.tfc3 = nn.Linear(in_features=256, out_features=inplanes ** 2, bias=False)

    def forward(self, x):
        """Transform ``x`` with the predicted matrix.

        Args:
            x: tensor of shape ``(B, k, N)`` (batch, channels, points).

        Returns:
            Tensor of shape ``(B, k, N)``: ``(T + I) @ x`` where ``T`` is the
            predicted ``(B, k, k)`` matrix and ``I`` the identity.

        Raises:
            AssertionError: if the channel dimension of ``x`` is not ``k``.
        """
        # input is channels-first: (Batch, Channels, Numbers)
        B, C, _ = x.size()
        assert (C == self.k), "input size is not suitable for the T-Net model!"
        # conv operations (no ReLU after the last batch norm)
        x1 = self.relu(self.bn1(self.tconv1(x)))
        x2 = self.relu(self.bn2(self.tconv2(x1)))
        x3 = self.bn3(self.tconv3(x2))
        # maxpool over the point axis for a global descriptor: (B, 1024)
        maxpool_x, _ = torch.max(x3, dim=2)
        # fc operations
        x4 = self.relu(self.bnf1(self.tfc1(maxpool_x)))
        x5 = self.relu(self.bnf2(self.tfc2(x4)))
        x6 = self.tfc3(x5)
        # reshape from (B, k*k) to transform matrix (B, k, k)
        trans_matrix = torch.reshape(x6, (B, self.k, self.k))
        # Identity created on the same device/dtype as the prediction so the
        # module also works on GPU and in reduced precision; the (k, k)
        # matrix broadcasts over the batch dimension.
        iden_matrix = torch.eye(self.k, dtype=trans_matrix.dtype, device=trans_matrix.device)
        # output the multiplied result
        out = torch.matmul(trans_matrix + iden_matrix, x)
        return out
class pointnet_encoder(nn.Module):
    """PointNet feature extractor.

    Takes a point cloud of shape ``(B, N, 3)`` and returns both the
    per-point 64-channel features (after the second T-Net) and a 1024-d
    global feature obtained by max-pooling over the points.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        # Input alignment followed by the first shared MLP (3 -> 64).
        self.trans1 = tnet(inplanes=3)
        self.econv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        # Feature alignment followed by the shared MLPs (64 -> 128 -> 1024).
        self.trans2 = tnet(inplanes=64)
        self.econv2 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm1d(128)
        self.econv3 = nn.Conv1d(128, 1024, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: tensor of shape ``(B, N, 3)``.

        Returns:
            ``(seg, glb)`` where ``seg`` has shape ``(B, N, 64)`` and ``glb``
            has shape ``(B, 1024)``.
        """
        # Conv1d works channels-first: (B, N, 3) -> (B, 3, N).
        channels_first = x.permute(0, 2, 1)
        # Stage 1: 3x3 input transform.
        aligned_points = self.trans1(channels_first)
        # Stage 2: shared MLP.
        feat64 = self.relu(self.bn1(self.econv1(aligned_points)))
        # Stage 3: 64x64 feature transform.
        aligned_feat = self.trans2(feat64)
        # Stage 4: 64-128-1024 shared MLPs.
        feat128 = self.relu(self.bn2(self.econv2(aligned_feat)))
        feat1024 = self.relu(self.bn3(self.econv3(feat128)))
        # Max over the point axis gives the global descriptor (B, 1024).
        glb, _ = torch.max(feat1024, dim=2)
        # Per-point features go back to points-first layout (B, N, 64).
        seg = aligned_feat.permute(0, 2, 1)
        return seg, glb
class pointnet_seg(nn.Module):
    """PointNet segmentation head on top of ``pointnet_encoder``.

    Each point's 64-d feature is concatenated with the 1024-d global feature
    (1088 channels) and passed through a 512-256-128 shared MLP followed by
    a final kernel-size-1 convolution producing ``num_classes`` scores.

    Args:
        num_classes: number of per-point output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.encoder = pointnet_encoder()
        self.relu = nn.ReLU(inplace=True)
        self.sconv1 = nn.Conv1d(1088, 512, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(512)
        self.sconv2 = nn.Conv1d(512, 256, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm1d(256)
        self.sconv3 = nn.Conv1d(256, 128, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(128)
        self.finalconv = nn.Conv1d(128, num_classes, kernel_size=1, bias=True)

    def forward(self, x):
        """Compute per-point class scores.

        Args:
            x: tensor of shape ``(B, N, C)``.

        Returns:
            Tensor of shape ``(B, N, num_classes)``.
        """
        _, num_points, _ = x.size()
        point_feat, global_feat = self.encoder(x)
        # Copy the global feature once per point: (B, 1024) -> (B, N, 1024).
        tiled_global = global_feat.unsqueeze(1).repeat(1, num_points, 1)
        # Concatenate along the channel axis, then go channels-first.
        hidden = torch.cat([point_feat, tiled_global], dim=2).permute(0, 2, 1)
        stages = (
            (self.sconv1, self.bn1),
            (self.sconv2, self.bn2),
            (self.sconv3, self.bn3),
        )
        for conv, norm in stages:
            hidden = self.relu(norm(conv(hidden)))
        # Back to points-first layout for the caller.
        return self.finalconv(hidden).permute(0, 2, 1)
def main():
    """Smoke-test the segmentation model on random points.

    Runs one forward pass on a random ``(10, 100, 3)`` batch, prints the
    score shape, then prints every ``state_dict`` entry with its shape and
    ``requires_grad`` flag.
    """
    net = pointnet_seg(num_classes=13)
    cloud = torch.randn(10, 100, 3)
    logits = net(cloud)
    print("score shape is {}".format(logits.size()))
    # keep_vars=True keeps the tensors attached so requires_grad is meaningful.
    entries = net.state_dict(keep_vars=True)
    for key in entries:
        tensor = entries[key]
        print(key, tensor.shape, tensor.requires_grad)


if __name__ == "__main__":
    main()