PointNet++ MSG

MSG版本训练得非常慢,不过总算是复现出来了;之后再实现一下MRG版本,或者不带多尺度分组的基础版本。

 

 

 

关于插值:

  

 

 

注意:set abstraction 层一般不会像PointNet那样用对称函数把所有点压成一个全局向量(group all 的那一层除外),它输出的是 B * C2 * D 的特征(其中C2为仍保留下来的点数,D为特征维数)。因此插值时不是简单地repeat再拼接,而是先为每个点找k近邻;只有上一层只剩1个点(全局特征)时,代码中才直接repeat。

而在PointNet中,全局特征经过对称函数后直接变成了1 * 1024(或者说1024 * 1)的向量,所以可以直接repeat后拼接。

例如第一次插值时,上一层特征有C2个点,当前层的局部特征有C1个点,则对每个局部特征点,从C2个点中找出它的k个最近邻点(代码中k=3),

取出这些近邻点的特征,按距离的倒数(代码中用的是平方距离的倒数)做加权平均,再与C1个点的局部特征拼接到一起,即完成插值;之后再接一个unit PointNet(即共享的1 * 1卷积)。


 

configuration.py:

import torch.cuda


class config:
    """Global settings shared by the data, model and training scripts."""

    # Compute device: prefer the GPU whenever PyTorch can see one.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Filesystem locations (machine-specific absolute paths).
    dataset_root = 'C:/Users/Dell/PycharmProjects/PointNet++/dataset'
    checkpoint_root = 'C:/Users/Dell/PycharmProjects/PointNet++/checkpoint'

    # Training hyper-parameters.
    num_epochs = 10
    batch_size = 4
    # Number of output classes of the segmentation head.
    num_seg = 40
View Code

Model.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as Dataloader
import torch.utils.data as data
from configuration import config

cfg = config()

def index_points(points, idx):
    """
    Gather, batch by batch, the rows of `points` selected by `idx`.

    Input:
        points: input points data, [B, N, C]
        idx: integer index tensor whose first dimension is B, e.g. [B, S] or [B, S, K]
    Return:
        new_points: indexed points data, idx.shape + [C] (e.g. [B, S, C] or [B, S, K, C])
    """
    # One batch index per entry of idx: shape [B, 1, ..., 1] broadcast to idx.shape.
    broadcast_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = (
        torch.arange(points.shape[0], dtype=torch.long)
        .to(points.device)
        .view(broadcast_shape)
        .expand(idx.shape)
    )
    return points[batch_indices, idx, :]



def farthest_sample(data, sample_n):
    """
    Farthest point sampling (FPS).

    Input:
        data: point cloud [B, N, C]
        sample_n: number of points to select per cloud
    Return:
        the selected points [B, sample_n, C] (rows of `data` at the chosen indices)

    The first point of every cloud is drawn uniformly at random. Each following
    point is the one whose distance to the already-selected set is largest. That
    distance is the minimum over all selected points, so it is tracked with a
    running min.
    """
    B, N, C = data.size()
    device = data.device
    batch_idx = torch.arange(B, dtype=torch.long, device=device)
    # min_dis[b, i]: squared distance from point i to its nearest selected point.
    # It starts at +inf so the first update equals the distance to the first centre.
    min_dis = torch.full((B, N), float('inf'), dtype=data.dtype, device=device)
    u = torch.randint(0, N, (B,), device=device)
    ret = torch.zeros((B, sample_n), dtype=torch.long, device=device)
    for i in range(sample_n):
        ret[:, i] = u
        cen = data[batch_idx, u, :].view(B, 1, C)
        distance = torch.sum((data - cen) ** 2, -1)
        min_dis = torch.min(min_dis, distance)
        # Selected points now have min_dis == 0, so the argmax favours unselected points.
        u = torch.max(min_dis, -1)[1]
    return data[batch_idx.unsqueeze(1), ret]

def square_distance(data, center):
    """
    Pairwise squared Euclidean distances between centres and points.

    Input:
        data: point cloud [B, N, C]
        center: centres [B, S, C]
    Return:
        dis [B, S, N], dis[i][j][k] = squared distance from the j-th centre
        of cloud i to its k-th point. The result is on the same device and
        has the same dtype as the inputs.
    """
    # Broadcast [B, S, 1, C] - [B, 1, N, C] instead of looping over centres.
    diff = center.unsqueeze(2) - data.unsqueeze(1)
    return torch.sum(diff ** 2, -1)


def query_ball_point(data, dis, radius, k):
    """
    Ball query: pick up to k neighbour indices inside `radius` of every centre.

    Input:
        data: point cloud [B, N, C]. Not used by the computation; kept for API compatibility.
        dis: squared distances [B, n, N]; dis[b, j, i] = |centre_j - point_i|^2
        radius: ball radius (compared with dis as radius ** 2)
        k: number of neighbour indices per centre
    Return:
        group: LongTensor [B, n, min(k, N)] of point indices, in ascending index order.
        If a ball holds fewer than k points, the remaining slots repeat the first
        (smallest) in-ball index. If a ball is empty, the nearest point is used instead.
    """
    B, n, N = dis.size()
    group = torch.arange(N, dtype=torch.long, device=dis.device).repeat(B, n, 1)
    # Mark out-of-ball points with the sentinel N so they sort to the end.
    group[dis > radius ** 2] = N
    group = group.sort(dim=-1)[0][:, :, :k]
    first = group[:, :, :1]
    # An empty ball would leave the out-of-range sentinel; fall back to the nearest point.
    nearest = dis.argmin(dim=-1, keepdim=True)
    first = torch.where(first == N, nearest, first)
    group = torch.where(group == N, first.expand_as(group), group)
    return group


# def cal_coor(data, group):
#     B, n, k = group.size()
#     B_list = torch.arange(B, dtype = torch.long)
#     res = torch.zeros([B, n, k, 3])
#     for i in range(n):
#         for j in range(k):
#             res[:, i, j, :] = data[B_list, group[:, i, j]]
#     return res

def sample_and_group(xyz, point, radius, nsample, dis):
    """
    Group coordinates and features around each centre using a ball query.

    Input:
        xyz: point coordinates [B, N, 3]
        point: per-point features [B, N, C]
        radius, nsample: ball radius and neighbours per centre (passed to query_ball_point)
        dis: squared centre-to-point distances [B, n, N]
    Return:
        (grouped coordinates [B, n, nsample, 3], grouped features [B, n, nsample, C])
    """
    neighbour_idx = query_ball_point(xyz, dis, radius, nsample)
    grouped_xyz = index_points(xyz, neighbour_idx)
    grouped_point = index_points(point, neighbour_idx)
    return grouped_xyz, grouped_point




class PointNetSetAbstractionMsg(nn.Module):
    """
    Multi-scale-grouping (MSG) set abstraction layer.

    npoint: number of sampled centres; None means "group all points into one centre".
    radius_list / nsample_list: ball radius and neighbour count of every scale.
    in_channel: channels of the incoming per-point features.
    mlp_list: one list of 1x1-conv output widths per scale.
    """

    def __init__(self, npoint = None, radius_list = None, nsample_list = None, in_channel = None, mlp_list = None):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for mlp in mlp_list:
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel
            for out_channel in mlp:
                # Input is B * C * k * n, so a 1x1 Conv2d acts as the shared fully connected layer.
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def _mlp_max(self, scale, grouped):
        """Run the MLP of branch `scale` on [B, C, k, n] features, then max-pool over k -> [B, C', n]."""
        for conv, bn in zip(self.conv_blocks[scale], self.bn_blocks[scale]):
            grouped = F.relu(bn(conv(grouped)))
        return torch.max(grouped, dim=2)[0]

    def forward(self, xyz, point):
        """
        xyz: coordinates [B, 3, N]; point: features [B, D, N], or None to use xyz as features.
        Returns (centres [B, 3, npoint], features [B, sum of last MLP widths, npoint]).
        For npoint=None the result is ([B, 3, 1] zeros, [B, C', 1]).
        """
        xyz = xyz.transpose(1, 2)
        # xyz drives the grouping; point holds the features that get grouped and convolved.
        point = xyz if point is None else point.transpose(1, 2)
        if self.npoint is None:
            # Group all: the whole cloud is a single group, B * 1 * N * D -> B * D * N * 1.
            new_point = self._mlp_max(0, point.unsqueeze(1).permute(0, 3, 2, 1))
            center = xyz.new_zeros(xyz.size(0), 1, xyz.size(2))
            return center.transpose(1, 2), new_point
        center = farthest_sample(xyz, self.npoint)
        dis = square_distance(xyz, center)
        point_list = []
        for i, (radius, nsample) in enumerate(zip(self.radius_list, self.nsample_list)):
            _, new_point = sample_and_group(xyz, point, radius, nsample, dis)
            # B * n * k * C -> B * C * k * n; the max over k keeps the strongest response per group.
            point_list.append(self._mlp_max(i, new_point.permute(0, 3, 2, 1)))
        # MSG: concatenate the features of all scales along the channel axis.
        return center.transpose(2, 1), torch.cat(point_list, dim = 1)

class PointNetSetAbstraction(nn.Module):
    """
    Single-scale set abstraction layer.

    npoint / radius / nsample: sampled centres, ball radius and neighbours per ball.
    in_channel: channels of the incoming per-point features.
    mlp: output widths of the shared 1x1 convolutions.
    group_all: treat the whole cloud as one group (npoint/radius/nsample ignored).
    """

    def __init__(self, npoint = None, radius = None, nsample = None, in_channel = None, mlp = None, group_all = False):
        super(PointNetSetAbstraction, self).__init__()
        self.group_all = group_all
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.conv_block = nn.ModuleList()
        self.bn_block = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.conv_block.append(nn.Conv2d(last_channel, out_channel, 1))
            self.bn_block.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def _mlp_max(self, grouped):
        """Run the shared MLP on [B, C, k, n] features and max-pool over k -> [B, C', n]."""
        for conv, bn in zip(self.conv_block, self.bn_block):
            grouped = F.relu(bn(conv(grouped)))
        return torch.max(grouped, dim = 2)[0]

    def forward(self, xyz, point):
        """
        xyz: coordinates [B, 3, N]; point: features [B, D, N], or None to use xyz as features.
        Returns (centres [B, 3, npoint], features [B, C', npoint]).
        With group_all the result is ([B, 3, 1] zeros, [B, C', 1]).
        """
        xyz = xyz.transpose(1, 2)
        point = xyz if point is None else point.transpose(1 ,2)
        if not self.group_all:
            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            _, new_point = sample_and_group(xyz, point, self.radius, self.nsample, dis)
            # B * n * k * C -> B * C * k * n before the 1x1 convolutions.
            new_point = self._mlp_max(new_point.permute(0, 3, 2, 1))
            return center.transpose(1, 2), new_point
        # Group all: B * 1 * N * D -> B * D * N * 1.
        new_point = self._mlp_max(point.unsqueeze(1).permute(0, 3, 2, 1))
        # Create the placeholder centre on the input's device/dtype.
        center = xyz.new_zeros(xyz.size(0), 1, xyz.size(2))
        return center.transpose(1, 2), new_point




class PointNetFeaturePropagation(nn.Module):
    """
    Feature propagation (upsampling) layer.

    Interpolates the features of the sparse set xyz2 onto the dense set xyz1.
    It uses the k = 3 nearest xyz2 points (fewer if xyz2 is smaller), weighted by
    1 / (squared distance + 1e-8). The result is concatenated with point1 and
    passed through a 1x1-conv MLP ("unit PointNet").
    """

    def __init__(self, inchannel, mlp_list):
        super().__init__()
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        last_channel = inchannel
        for out_channel in mlp_list:
            self.conv_blocks.append(nn.Conv1d(last_channel, out_channel, 1))
            self.bn_blocks.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    @staticmethod
    def _interpolate(xyz1, xyz2, point2, k = 3):
        """Weighted average of the k nearest xyz2 features for every xyz1 point.

        xyz1 [B, N, 3], xyz2 [B, n, 3], point2 [B, n, D] -> [B, N, D].
        """
        # dis[b, i, j] = squared distance between xyz1[b, i] and xyz2[b, j].
        dis = torch.sum((xyz1.unsqueeze(2) - xyz2.unsqueeze(1)) ** 2, -1)
        k = min(k, xyz2.size(1))
        # The nearest neighbour (index 0) is kept: a dense point is generally not in xyz2.
        dis, idx = torch.topk(dis, k, dim = -1, largest = False)  # B * N * k
        batch = torch.arange(point2.size(0), device = point2.device).view(-1, 1, 1)
        feature = point2[batch, idx]  # B * N * k * D
        w = 1.0 / (dis + 1e-8)
        ratio = w / torch.sum(w, dim = -1, keepdim = True)
        return torch.sum(feature * ratio.unsqueeze(-1), dim = 2)

    def forward(self, xyz1, xyz2, point1, point2):
        """
        xyz1 [B, 3, N], xyz2 [B, 3, n], point1 [B, D1, N] or None, point2 [B, D2, n].
        Returns [B, last MLP width, N] (or [B, D1 + D2, N] when mlp_list is empty).
        """
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
        point2 = point2.transpose(1, 2)
        N = xyz1.size(1)
        if xyz2.size(1) == 1:
            # A single global feature is simply broadcast to every dense point.
            point2 = point2.repeat(1, N, 1)
        else:
            point2 = self._interpolate(xyz1, xyz2, point2)
        if point1 is not None:
            point = torch.cat([point1.transpose(1, 2), point2], dim = -1)
        else:
            point = point2
        point = point.transpose(1, 2)
        for conv, bn in zip(self.conv_blocks, self.bn_blocks):
            point = F.relu(bn(conv(point)))
        return point


class PointNet_add_Msg(nn.Module):
    """
    PointNet++ MSG segmentation network.

    Input xyz: [B, 3, N] coordinates (also used as the first-layer features).
    Output: [B * N, cfg.num_seg] per-point class scores.
    """

    def __init__(self):
        super().__init__()
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 64 + 128 + 128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstractionMsg(in_channel = 128 + 256 + 256 , mlp_list = [[256, 512, 1024]])
        # fp input widths: 1664 = 1024 + 640, 576 = 256 + 320, 131 = 128 + 3.
        self.fp1 = PointNetFeaturePropagation(1664, [256, 256])
        self.fp2 = PointNetFeaturePropagation(576, [256, 128])
        self.fp3 = PointNetFeaturePropagation(131, [128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.conv2 = nn.Conv1d(128, cfg.num_seg, 1)

        self.bn1 = nn.BatchNorm1d(128)

    def forward(self, xyz):
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point2 = self.fp1(xyz2, xyz3, point2, point3)
        point1 = self.fp2(xyz1, xyz2, point1, point2)
        point0 = self.fp3(xyz, xyz1, xyz, point1)
        # F.dropout defaults to training=True; tie it to the module mode so eval() disables it.
        point = F.dropout(F.relu(self.bn1(self.conv1(point0))), 0.5, training = self.training)
        point = self.conv2(point)
        point = point.transpose(1, 2).contiguous()
        point = point.view(-1, cfg.num_seg)
        return point
View Code

 

 

train.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from DataSet import Dataset
from Model import PointNet_add
from configuration import config
import torch.utils.data.dataloader as DataLoader
from tensorboardX import SummaryWriter
import os


cfg = config()

if __name__ == '__main__':

    model = PointNet_add()
    model.to(cfg.device)
    dataset = Dataset(cfg.dataset_root)
    dataloader = DataLoader.DataLoader(dataset, batch_size = cfg.batch_size, shuffle = True)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
    loss = nn.CrossEntropyLoss()
    tbwrite = SummaryWriter(logdir = os.path.join(cfg.checkpoint_root, 'log'))
    model.train()
    for epoch in range(cfg.num_epochs):
        total_true = 0
        total_seen = 0
        total_loss = 0.0
        cnt = 0
        for xyz, label in tqdm(dataloader):
            # Inputs must live on the same device as the model parameters.
            xyz = xyz.to(cfg.device)
            # Segmentation: one label per point, flattened to match the [B * N, C] output.
            label = label.to(cfg.device).view(-1)
            optimizer.zero_grad()
            output = model(xyz)
            loss_value = loss(output, label)
            loss_value.backward()
            optimizer.step()
            pred = torch.max(output, -1)[1]
            total_true += torch.sum(pred == label).item()
            total_seen += label.numel()
            total_loss += loss_value.item()
            cnt += 1
        mean_loss = total_loss / float(max(cnt, 1))
        # Accuracy is per label (per point here), not per point cloud.
        accuracy = total_true / float(max(total_seen, 1))
        tbwrite.add_scalar('Loss', mean_loss, epoch)
        tbwrite.add_scalar('Accuracy', accuracy, epoch)
        print('mean_loss:{:.4f}, accuracy:{:.4f}'.format(mean_loss, accuracy))
        if (epoch + 1) % cfg.num_epochs == 0:
            state = {
                'model': model.state_dict()
            }
            torch.save(state, os.path.join(cfg.checkpoint_root, 'checkpoint_{}.pth'.format(epoch)))
    tbwrite.close()
View Code

 

DataSet.py:

import numpy as np
import torch.utils.data as data
import os
import random
import torch



class Dataset(data.Dataset):
    """
    Per-point segmentation dataset read from text files.

    root/points/<i>.* holds one point per row; root/points_label/<i>.* holds one
    label per row. Files are paired by the integer before the first '.'.
    Every item is resampled to exactly `num_points` points.
    """

    def __init__(self, root, num_points = 2500):
        super().__init__()
        self.root = root
        self.num_points = num_points
        data_list = os.listdir(os.path.join(root, 'points'))
        label_list = os.listdir(os.path.join(root, 'points_label'))
        self.data_list = sorted(data_list, key = lambda x : int(x.split('.')[0]))
        self.label_list = sorted(label_list, key = lambda x : int(x.split('.')[0]))

    def __getitem__(self, index):
        """Return (points float32 [C, num_points], labels int64 [num_points])."""
        # ndmin keeps single-point files 2-D (points) / 1-D (labels).
        points = np.loadtxt(os.path.join(self.root, 'points', self.data_list[index]), ndmin = 2)
        labels = np.loadtxt(os.path.join(self.root, 'points_label', self.label_list[index]), ndmin = 1)
        n = points.shape[0]
        if n == 0:
            raise ValueError('empty point file: {}'.format(self.data_list[index]))
        if n >= self.num_points:
            # Random subset without replacement.
            sample_list = random.sample(range(n), self.num_points)
        else:
            # Keep every point and pad with random duplicates. Sampling with replacement
            # also works when n < num_points / 2, where random.sample would raise.
            sample_list = list(range(n)) + random.choices(range(n), k = self.num_points - n)
        points = points[sample_list, :]
        labels = labels[sample_list]

        # Labels must be a LongTensor and data float32.
        label = torch.tensor(labels).type(torch.LongTensor)
        points = torch.tensor(points.T).to(torch.float32)
        return points, label

    def __len__(self):
        return len(self.label_list)
View Code

 

 

 

 

以下是分类与分割合在一起的代码,包含普通(单尺度)版本和MSG版本。

Dataset.py:

import numpy as np
import torch.utils.data as data
import os
import random
import torch
import h5py



class Dataset(data.Dataset):
    """
    Per-point segmentation dataset read from text files.

    root/points/<i>.* holds one point per row; root/points_label/<i>.* holds one
    label per row. Files are paired by the integer before the first '.'.
    Every item is resampled to exactly `num_points` points.
    """

    def __init__(self, root, num_points = 2500):
        super().__init__()
        self.root = root
        self.num_points = num_points
        data_list = os.listdir(os.path.join(root, 'points'))
        label_list = os.listdir(os.path.join(root, 'points_label'))
        self.data_list = sorted(data_list, key = lambda x : int(x.split('.')[0]))
        self.label_list = sorted(label_list, key = lambda x : int(x.split('.')[0]))

    def __getitem__(self, index):
        """Return (points float32 [C, num_points], labels int64 [num_points])."""
        # ndmin keeps single-point files 2-D (points) / 1-D (labels).
        points = np.loadtxt(os.path.join(self.root, 'points', self.data_list[index]), ndmin = 2)
        labels = np.loadtxt(os.path.join(self.root, 'points_label', self.label_list[index]), ndmin = 1)
        n = points.shape[0]
        if n == 0:
            raise ValueError('empty point file: {}'.format(self.data_list[index]))
        if n >= self.num_points:
            # Random subset without replacement.
            sample_list = random.sample(range(n), self.num_points)
        else:
            # Keep every point and pad with random duplicates. Sampling with replacement
            # also works when n < num_points / 2, where random.sample would raise.
            sample_list = list(range(n)) + random.choices(range(n), k = self.num_points - n)
        points = points[sample_list, :]
        labels = labels[sample_list]

        # Labels must be a LongTensor and data float32.
        label = torch.tensor(labels).type(torch.LongTensor)
        points = torch.tensor(points.T).to(torch.float32)
        return points, label

    def __len__(self):
        return len(self.label_list)



class claDataset(data.Dataset):
    """
    Classification dataset stored in a single HDF5 file.

    Reads the 'data' and 'label' datasets fully into memory. Labels are taken
    from column 0 of 'label'.
    Items are (points float32 [C, N], label int64 scalar).
    """

    def __init__(self, root):
        super(claDataset, self).__init__()
        # Load everything, then release the file handle immediately.
        with h5py.File(root, 'r') as dataset:
            self.data = dataset['data'][:]
            self.label = dataset['label'][:][:, 0]

    def __getitem__(self, index):
        label = torch.tensor(self.label[index]).type(torch.LongTensor)
        # The model's convolutions expect float32 input.
        points = torch.tensor(self.data[index].T).to(torch.float32)
        return points, label

    def __len__(self):
        return len(self.label)
View Code

configuration.py:

import torch.cuda


class config:
    """Global settings shared by the data, model and training scripts."""

    # Compute device: prefer the GPU whenever PyTorch can see one.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Filesystem locations (machine-specific absolute paths).
    dataset_root = 'C:/Users/Dell/PycharmProjects/PointNet++/dataset'
    checkpoint_root = 'C:/Users/Dell/PycharmProjects/PointNet++/checkpoint'
    cladataset_root = 'H:/DataSet/modelnet40_ply_hdf5_2048/ply_data_train0.h5'

    # Training hyper-parameters.
    num_epochs = 10
    batch_size = 4
    # Output widths of the segmentation and classification heads.
    num_seg = 50
    num_classes = 40
View Code

Model.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as Dataloader
import torch.utils.data as data
from configuration import config

cfg = config()

def index_points(points, idx):
    """
    Gather, batch by batch, the rows of `points` selected by `idx`.

    Input:
        points: input points data, [B, N, C]
        idx: integer index tensor whose first dimension is B, e.g. [B, S] or [B, S, K]
    Return:
        new_points: indexed points data, idx.shape + [C] (e.g. [B, S, C] or [B, S, K, C])
    """
    # One batch index per entry of idx: shape [B, 1, ..., 1] broadcast to idx.shape.
    broadcast_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = (
        torch.arange(points.shape[0], dtype=torch.long)
        .to(points.device)
        .view(broadcast_shape)
        .expand(idx.shape)
    )
    return points[batch_indices, idx, :]



def farthest_sample(data, sample_n):
    """
    Farthest point sampling (FPS).

    Input:
        data: point cloud [B, N, C]
        sample_n: number of points to select per cloud
    Return:
        the selected points [B, sample_n, C] (rows of `data` at the chosen indices)

    The first point of every cloud is drawn uniformly at random. Each following
    point is the one whose distance to the already-selected set is largest. That
    distance is the minimum over all selected points, so it is tracked with a
    running min.
    """
    B, N, C = data.size()
    device = data.device
    batch_idx = torch.arange(B, dtype=torch.long, device=device)
    # min_dis[b, i]: squared distance from point i to its nearest selected point.
    # It starts at +inf so the first update equals the distance to the first centre.
    min_dis = torch.full((B, N), float('inf'), dtype=data.dtype, device=device)
    u = torch.randint(0, N, (B,), device=device)
    ret = torch.zeros((B, sample_n), dtype=torch.long, device=device)
    for i in range(sample_n):
        ret[:, i] = u
        cen = data[batch_idx, u, :].view(B, 1, C)
        distance = torch.sum((data - cen) ** 2, -1)
        min_dis = torch.min(min_dis, distance)
        # Selected points now have min_dis == 0, so the argmax favours unselected points.
        u = torch.max(min_dis, -1)[1]
    return data[batch_idx.unsqueeze(1), ret]

def square_distance(data, center):
    """
    Pairwise squared Euclidean distances between centres and points.

    Input:
        data: point cloud [B, N, C]
        center: centres [B, S, C]
    Return:
        dis [B, S, N], dis[i][j][k] = squared distance from the j-th centre
        of cloud i to its k-th point. The result is on the same device and
        has the same dtype as the inputs.
    """
    # Broadcast [B, S, 1, C] - [B, 1, N, C] instead of looping over centres.
    diff = center.unsqueeze(2) - data.unsqueeze(1)
    return torch.sum(diff ** 2, -1)


def query_ball_point(data, dis, radius, k):
    """
    Ball query: pick up to k neighbour indices inside `radius` of every centre.

    Input:
        data: point cloud [B, N, C]. Not used by the computation; kept for API compatibility.
        dis: squared distances [B, n, N]; dis[b, j, i] = |centre_j - point_i|^2
        radius: ball radius (compared with dis as radius ** 2)
        k: number of neighbour indices per centre
    Return:
        group: LongTensor [B, n, min(k, N)] of point indices, in ascending index order.
        If a ball holds fewer than k points, the remaining slots repeat the first
        (smallest) in-ball index. If a ball is empty, the nearest point is used instead.
    """
    B, n, N = dis.size()
    group = torch.arange(N, dtype=torch.long, device=dis.device).repeat(B, n, 1)
    # Mark out-of-ball points with the sentinel N so they sort to the end.
    group[dis > radius ** 2] = N
    group = group.sort(dim=-1)[0][:, :, :k]
    first = group[:, :, :1]
    # An empty ball would leave the out-of-range sentinel; fall back to the nearest point.
    nearest = dis.argmin(dim=-1, keepdim=True)
    first = torch.where(first == N, nearest, first)
    group = torch.where(group == N, first.expand_as(group), group)
    return group


# def cal_coor(data, group):
#     B, n, k = group.size()
#     B_list = torch.arange(B, dtype = torch.long)
#     res = torch.zeros([B, n, k, 3])
#     for i in range(n):
#         for j in range(k):
#             res[:, i, j, :] = data[B_list, group[:, i, j]]
#     return res

def sample_and_group(xyz, point, radius, nsample, dis):
    """
    Group coordinates and features around each centre using a ball query.

    Input:
        xyz: point coordinates [B, N, 3]
        point: per-point features [B, N, C]
        radius, nsample: ball radius and neighbours per centre (passed to query_ball_point)
        dis: squared centre-to-point distances [B, n, N]
    Return:
        (grouped coordinates [B, n, nsample, 3], grouped features [B, n, nsample, C])
    """
    neighbour_idx = query_ball_point(xyz, dis, radius, nsample)
    grouped_xyz = index_points(xyz, neighbour_idx)
    grouped_point = index_points(point, neighbour_idx)
    return grouped_xyz, grouped_point




class PointNetSetAbstractionMsg(nn.Module):
    """
    Multi-scale-grouping (MSG) set abstraction layer.

    npoint: number of sampled centres; None means "group all points into one centre".
    radius_list / nsample_list: ball radius and neighbour count of every scale.
    in_channel: channels of the incoming per-point features.
    mlp_list: one list of 1x1-conv output widths per scale.
    """

    def __init__(self, npoint = None, radius_list = None, nsample_list = None, in_channel = None, mlp_list = None):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for mlp in mlp_list:
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel
            for out_channel in mlp:
                # Input is B * C * k * n, so a 1x1 Conv2d acts as the shared fully connected layer.
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def _mlp_max(self, scale, grouped):
        """Run the MLP of branch `scale` on [B, C, k, n] features, then max-pool over k -> [B, C', n]."""
        for conv, bn in zip(self.conv_blocks[scale], self.bn_blocks[scale]):
            grouped = F.relu(bn(conv(grouped)))
        return torch.max(grouped, dim=2)[0]

    def forward(self, xyz, point):
        """
        xyz: coordinates [B, 3, N]; point: features [B, D, N], or None to use xyz as features.
        Returns (centres [B, 3, npoint], features [B, sum of last MLP widths, npoint]).
        For npoint=None the result is ([B, 3, 1] zeros, [B, C', 1]).
        """
        xyz = xyz.transpose(1, 2)
        # xyz drives the grouping; point holds the features that get grouped and convolved.
        point = xyz if point is None else point.transpose(1, 2)
        if self.npoint is None:
            # Group all: the whole cloud is a single group, B * 1 * N * D -> B * D * N * 1.
            new_point = self._mlp_max(0, point.unsqueeze(1).permute(0, 3, 2, 1))
            center = xyz.new_zeros(xyz.size(0), 1, xyz.size(2))
            return center.transpose(1, 2), new_point
        center = farthest_sample(xyz, self.npoint)
        dis = square_distance(xyz, center)
        point_list = []
        for i, (radius, nsample) in enumerate(zip(self.radius_list, self.nsample_list)):
            _, new_point = sample_and_group(xyz, point, radius, nsample, dis)
            # B * n * k * C -> B * C * k * n; the max over k keeps the strongest response per group.
            point_list.append(self._mlp_max(i, new_point.permute(0, 3, 2, 1)))
        # MSG: concatenate the features of all scales along the channel axis.
        return center.transpose(2, 1), torch.cat(point_list, dim = 1)

class PointNetSetAbstraction(nn.Module):
    """
    Single-scale set abstraction layer.

    npoint / radius / nsample: sampled centres, ball radius and neighbours per ball.
    in_channel: channels of the incoming per-point features.
    mlp: output widths of the shared 1x1 convolutions.
    group_all: treat the whole cloud as one group (npoint/radius/nsample ignored).
    """

    def __init__(self, npoint = None, radius = None, nsample = None, in_channel = None, mlp = None, group_all = False):
        super(PointNetSetAbstraction, self).__init__()
        self.group_all = group_all
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.conv_block = nn.ModuleList()
        self.bn_block = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.conv_block.append(nn.Conv2d(last_channel, out_channel, 1))
            self.bn_block.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def _mlp_max(self, grouped):
        """Run the shared MLP on [B, C, k, n] features and max-pool over k -> [B, C', n]."""
        for conv, bn in zip(self.conv_block, self.bn_block):
            grouped = F.relu(bn(conv(grouped)))
        return torch.max(grouped, dim = 2)[0]

    def forward(self, xyz, point):
        """
        xyz: coordinates [B, 3, N]; point: features [B, D, N], or None to use xyz as features.
        Returns (centres [B, 3, npoint], features [B, C', npoint]).
        With group_all the result is ([B, 3, 1] zeros, [B, C', 1]).
        """
        xyz = xyz.transpose(1, 2)
        point = xyz if point is None else point.transpose(1 ,2)
        if not self.group_all:
            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            _, new_point = sample_and_group(xyz, point, self.radius, self.nsample, dis)
            # B * n * k * C -> B * C * k * n before the 1x1 convolutions.
            new_point = self._mlp_max(new_point.permute(0, 3, 2, 1))
            return center.transpose(1, 2), new_point
        # Group all: B * 1 * N * D -> B * D * N * 1.
        new_point = self._mlp_max(point.unsqueeze(1).permute(0, 3, 2, 1))
        # Create the placeholder centre on the input's device/dtype.
        center = xyz.new_zeros(xyz.size(0), 1, xyz.size(2))
        return center.transpose(1, 2), new_point




class PointNetFeaturePropagation(nn.Module):
    """
    Feature propagation (upsampling) layer.

    Interpolates the features of the sparse set xyz2 onto the dense set xyz1.
    It uses the k = 3 nearest xyz2 points (fewer if xyz2 is smaller), weighted by
    1 / (squared distance + 1e-8). The result is concatenated with point1 and
    passed through a 1x1-conv MLP ("unit PointNet").
    """

    def __init__(self, inchannel, mlp_list):
        super().__init__()
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        last_channel = inchannel
        for out_channel in mlp_list:
            self.conv_blocks.append(nn.Conv1d(last_channel, out_channel, 1))
            self.bn_blocks.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    @staticmethod
    def _interpolate(xyz1, xyz2, point2, k = 3):
        """Weighted average of the k nearest xyz2 features for every xyz1 point.

        xyz1 [B, N, 3], xyz2 [B, n, 3], point2 [B, n, D] -> [B, N, D].
        """
        # dis[b, i, j] = squared distance between xyz1[b, i] and xyz2[b, j].
        dis = torch.sum((xyz1.unsqueeze(2) - xyz2.unsqueeze(1)) ** 2, -1)
        k = min(k, xyz2.size(1))
        # The nearest neighbour (index 0) is kept: a dense point is generally not in xyz2.
        dis, idx = torch.topk(dis, k, dim = -1, largest = False)  # B * N * k
        batch = torch.arange(point2.size(0), device = point2.device).view(-1, 1, 1)
        feature = point2[batch, idx]  # B * N * k * D
        w = 1.0 / (dis + 1e-8)
        ratio = w / torch.sum(w, dim = -1, keepdim = True)
        return torch.sum(feature * ratio.unsqueeze(-1), dim = 2)

    def forward(self, xyz1, xyz2, point1, point2):
        """
        xyz1 [B, 3, N], xyz2 [B, 3, n], point1 [B, D1, N] or None, point2 [B, D2, n].
        Returns [B, last MLP width, N] (or [B, D1 + D2, N] when mlp_list is empty).
        """
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
        point2 = point2.transpose(1, 2)
        N = xyz1.size(1)
        if xyz2.size(1) == 1:
            # A single global feature is simply broadcast to every dense point.
            point2 = point2.repeat(1, N, 1)
        else:
            point2 = self._interpolate(xyz1, xyz2, point2)
        if point1 is not None:
            point = torch.cat([point1.transpose(1, 2), point2], dim = -1)
        else:
            point = point2
        point = point.transpose(1, 2)
        for conv, bn in zip(self.conv_blocks, self.bn_blocks):
            point = F.relu(bn(conv(point)))
        return point

class PointNet_add_cla(nn.Module):
    """
    PointNet++ single-scale classification network.

    Input xyz: [B, 3, N] coordinates (also used as the first-layer features).
    Output: [B, cfg.num_classes] class scores.
    """

    def __init__(self):
        super(PointNet_add_cla, self).__init__()
        self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128])
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
        self.sa3 = PointNetSetAbstraction(in_channel = 256, mlp = [256, 512, 1024], group_all = True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, cfg.num_classes)

    def forward(self, xyz):
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        # sa3 groups all points, so point3 is [B, 1024, 1] -> [B, 1024].
        point = point3.view(point3.size(0), -1)
        # F.dropout defaults to training=True; tie it to the module mode so eval() disables it.
        point = F.dropout(F.relu(self.fc1(point)), 0.5, training = self.training)
        point = F.dropout(F.relu(self.fc2(point)), 0.5, training = self.training)
        point = self.fc3(point)
        return point




class PointNet_add(nn.Module):
    """
    PointNet++ single-scale segmentation network.

    Input xyz: [B, 3, N] coordinates (also used as the first-layer features).
    Output: [B * N, cfg.num_seg] per-point class scores.
    """

    def __init__(self):
        super(PointNet_add, self).__init__()
        self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128])
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
        self.sa3 = PointNetSetAbstraction(in_channel = 256, mlp = [256, 512, 1024], group_all = True)
        # fp input widths: 1280 = 1024 + 256, 384 = 256 + 128, 131 = 128 + 3.
        self.fp1 = PointNetFeaturePropagation(1280, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp3 = PointNetFeaturePropagation(131, [128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, cfg.num_seg, 1)

    def forward(self, xyz):
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point2 = self.fp1(xyz2, xyz3, point2, point3)
        point1 = self.fp2(xyz1, xyz2, point1, point2)
        point0 = self.fp3(xyz, xyz1, xyz, point1)
        # F.dropout defaults to training=True; tie it to the module mode so eval() disables it.
        point = F.dropout(F.relu(self.bn1(self.conv1(point0))), 0.5, training = self.training)
        point = F.dropout(F.relu(self.bn2(self.conv2(point))), 0.5, training = self.training)
        point = self.conv3(point)
        point = point.transpose(1, 2).contiguous()
        point = point.view(-1, cfg.num_seg)
        return point




class PointNet_add_Msg(nn.Module):
    """
    PointNet++ MSG segmentation network.

    Input xyz: [B, 3, N] coordinates (also used as the first-layer features).
    Output: [B * N, cfg.num_seg] per-point class scores.
    """

    def __init__(self):
        super().__init__()
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 64 + 128 + 128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstractionMsg(in_channel = 128 + 256 + 256 , mlp_list = [[256, 512, 1024]])
        # fp input widths: 1664 = 1024 + 640, 576 = 256 + 320, 131 = 128 + 3.
        self.fp1 = PointNetFeaturePropagation(1664, [256, 256])
        self.fp2 = PointNetFeaturePropagation(576, [256, 128])
        self.fp3 = PointNetFeaturePropagation(131, [128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.conv2 = nn.Conv1d(128, cfg.num_seg, 1)

        self.bn1 = nn.BatchNorm1d(128)

    def forward(self, xyz):
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point2 = self.fp1(xyz2, xyz3, point2, point3)
        point1 = self.fp2(xyz1, xyz2, point1, point2)
        point0 = self.fp3(xyz, xyz1, xyz, point1)
        # F.dropout defaults to training=True; tie it to the module mode so eval() disables it.
        point = F.dropout(F.relu(self.bn1(self.conv1(point0))), 0.5, training = self.training)
        point = self.conv2(point)
        point = point.transpose(1, 2).contiguous()
        point = point.view(-1, cfg.num_seg)
        return point
View Code

train.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from DataSet import Dataset, claDataset
from Model import PointNet_add_Msg, PointNet_add, PointNet_add_cla
from configuration import config
import torch.utils.data.dataloader as DataLoader
from tensorboardX import SummaryWriter
import os


cfg = config()

if __name__ == '__main__':

    # model = PointNet_add_Msg()
    # model = PointNet_add()
    model = PointNet_add_cla()
    model.to(cfg.device)
    # dataset = Dataset(cfg.dataset_root)
    dataset = claDataset(cfg.cladataset_root)
    dataloader = DataLoader.DataLoader(dataset, batch_size = cfg.batch_size, shuffle = True)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
    loss = nn.CrossEntropyLoss()
    tbwrite = SummaryWriter(logdir = os.path.join(cfg.checkpoint_root, 'log'))
    model.train()
    for epoch in range(cfg.num_epochs):
        total_true = 0
        total_seen = 0
        total_loss = 0.0
        cnt = 0
        for xyz, label in tqdm(dataloader):
            # Inputs must live on the same device as the model parameters.
            xyz = xyz.to(cfg.device)
            # Flatten so classification ([B]) and segmentation ([B, N]) labels both work.
            label = label.to(cfg.device).view(-1)
            optimizer.zero_grad()
            output = model(xyz)
            loss_value = loss(output, label)
            loss_value.backward()
            optimizer.step()
            pred = torch.max(output, -1)[1]
            total_true += torch.sum(pred == label).item()
            total_seen += label.numel()
            # .item() detaches the value; summing tensors would keep every graph alive.
            total_loss += loss_value.item()
            cnt += 1
        mean_loss = total_loss / float(max(cnt, 1))
        # Divide by the number of labels seen: clouds for classification, points for segmentation.
        accuracy = total_true / float(max(total_seen, 1))
        tbwrite.add_scalar('Loss', mean_loss, epoch)
        tbwrite.add_scalar('Accuracy', accuracy, epoch)
        print('mean_loss:{:.4f}, accuracy:{:.4f}'.format(mean_loss, accuracy))
        if (epoch + 1) % cfg.num_epochs == 0:
            state = {
                'model': model.state_dict()
            }
            torch.save(state, os.path.join(cfg.checkpoint_root, 'checkpoint_{}.pth'.format(epoch + 1)))
    tbwrite.close()
View Code

 

posted @ 2021-10-15 22:03  WTSRUVF  阅读(227)  评论(0编辑  收藏  举报