LSS (Lift-Splat-Shoot)

1. create_frustum

    def create_frustum(self):
        """Build the fixed (x, y, depth) frustum grid over the downsampled image.

        Reads ``self.data_aug_conf['final_dim']`` (image height and width),
        ``self.downsample`` (integer downsampling factor of the feature map) and
        ``self.grid_conf['dbound']`` (arguments passed to ``torch.arange`` to
        enumerate the depth values, e.g. ``[4, 45, 1]``).

        Returns:
            A non-trainable ``nn.Parameter`` of shape ``(D, fH, fW, 3)`` with
            ``fH = H // downsample``, ``fW = W // downsample`` and ``D`` the
            number of depth values.  The last axis holds ``(x, y, d)``: the
            pixel column, the pixel row (both spread evenly from 0 to the last
            pixel of the full-resolution image) and the depth value.
        """
        # make grid in image plane
        ogfH, ogfW = self.data_aug_conf['final_dim']  # e.g. ogfH=128, ogfW=352
        # e.g. downsample=16 -> fH=8, fW=22
        fH, fW = ogfH // self.downsample, ogfW // self.downsample

        # Depth values, e.g. dbound=[4, 45, 1] -> 4, 5, ..., 44 (D=41).
        depths = torch.arange(*self.grid_conf['dbound'], dtype=torch.float)
        D = depths.shape[0]
        # ds: [D, fH, fW]; every cell of depth slice k holds depths[k].
        ds = depths.view(-1, 1, 1).expand(-1, fH, fW)

        # fW column / fH row positions spanning pixel 0 .. last pixel of the
        # full-resolution image (e.g. x step 351/21 ~ 16.71, y step 127/7 ~ 18.14),
        # broadcast to [D, fH, fW] without copying.
        xs = torch.linspace(0, ogfW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
        ys = torch.linspace(0, ogfH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)

        # D x H x W x 3, last axis is (x, y, depth)
        frustum = torch.stack((xs, ys, ds), -1)
        return nn.Parameter(frustum, requires_grad=False)

test_create_frustum.py

import torch

# Toy configuration, small enough that the whole frustum can be printed.
downsample = 3
ogfH = 21
ogfW = 36

# Downsampled feature-map size: fH=7, fW=12
fH, fW = ogfH // downsample, ogfW // downsample

# Depth values arange(5, 15, 5) -> [5., 10.], broadcast to ds: [2, 7, 12]
depth_values = torch.arange(*[5, 15, 5], dtype=torch.float)
ds = depth_values.view(-1, 1, 1).expand(-1, fH, fW)
D = ds.shape[0]  # D = 2

# Column positions: 12 values from 0 to 35 (step 35/11 ~ 3.1818)
tmp_xs = torch.linspace(0, ogfW - 1, fW, dtype=torch.float)
# Row positions: 7 values from 0 to 20 (step 20/6 ~ 3.3333)
tmp_ys = torch.linspace(0, ogfH - 1, fH, dtype=torch.float)

# Add singleton axes so x varies along W ([1, 1, 12]) and y along H ([1, 7, 1]).
tmp_xs_view = tmp_xs.view(1, 1, fW)
tmp_ys_view = tmp_ys.view(1, fH, 1)

# Broadcast both to [D, fH, fW] = [2, 7, 12]
xs = tmp_xs_view.expand(D, fH, fW)
ys = tmp_ys_view.expand(D, fH, fW)

# frustum: D x H x W x 3 = [2, 7, 12, 3], last axis is (x, y, depth)
frustum = torch.stack((xs, ys, ds), -1)

# Dump each intermediate: a header with its shape, then its values.
for _title, _value in (
    ("==============>>>>>>>>>>>>>>>tmp_xs_view.shape=", tmp_xs_view),
    ("==============>>>>>>>>>>>>>>>xs.shape=", xs),
    ("==============>>>>>>>>>>>>>>>tmp_ys_view.shape=", tmp_ys_view),
    ("==============>>>>>>>>>>>>>>>ys.shape=", ys),
    ("==============>>>>>>>>>>>>>>>ds.shape=", ds),
    ("==============>>>>>>>>>>>>>>>frustum.shape=", frustum),
):
    print(_title, _value.shape)
    print(_value)

/media/algo/data_1/software/anconda_install/envs/pytorch1.7.0_general/bin/python3 /media/algo/data_1/project_others/0000paper/lss/project/lift-splat-shoot-master/0000/create_frustum.py
==============>>>>>>>>>>>>>>>tmp_xs_view.shape= torch.Size([1, 1, 12])
tensor([[[ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000]]])
==============>>>>>>>>>>>>>>>xs.shape= torch.Size([2, 7, 12])
tensor([[[ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000]],

        [[ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000],
         [ 0.0000,  3.1818,  6.3636,  9.5455, 12.7273, 15.9091, 19.0909,
          22.2727, 25.4545, 28.6364, 31.8182, 35.0000]]])
==============>>>>>>>>>>>>>>>tmp_ys_view.shape= torch.Size([1, 7, 1])
tensor([[[ 0.0000],
         [ 3.3333],
         [ 6.6667],
         [10.0000],
         [13.3333],
         [16.6667],
         [20.0000]]])
==============>>>>>>>>>>>>>>>ys.shape= torch.Size([2, 7, 12])
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 3.3333,  3.3333,  3.3333,  3.3333,  3.3333,  3.3333,  3.3333,
           3.3333,  3.3333,  3.3333,  3.3333,  3.3333],
         [ 6.6667,  6.6667,  6.6667,  6.6667,  6.6667,  6.6667,  6.6667,
           6.6667,  6.6667,  6.6667,  6.6667,  6.6667],
         [10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000],
         [13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333,
          13.3333, 13.3333, 13.3333, 13.3333, 13.3333],
         [16.6667, 16.6667, 16.6667, 16.6667, 16.6667, 16.6667, 16.6667,
          16.6667, 16.6667, 16.6667, 16.6667, 16.6667],
         [20.0000, 20.0000, 20.0000, 20.0000, 20.0000, 20.0000, 20.0000,
          20.0000, 20.0000, 20.0000, 20.0000, 20.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 3.3333,  3.3333,  3.3333,  3.3333,  3.3333,  3.3333,  3.3333,
           3.3333,  3.3333,  3.3333,  3.3333,  3.3333],
         [ 6.6667,  6.6667,  6.6667,  6.6667,  6.6667,  6.6667,  6.6667,
           6.6667,  6.6667,  6.6667,  6.6667,  6.6667],
         [10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000],
         [13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333,
          13.3333, 13.3333, 13.3333, 13.3333, 13.3333],
         [16.6667, 16.6667, 16.6667, 16.6667, 16.6667, 16.6667, 16.6667,
          16.6667, 16.6667, 16.6667, 16.6667, 16.6667],
         [20.0000, 20.0000, 20.0000, 20.0000, 20.0000, 20.0000, 20.0000,
          20.0000, 20.0000, 20.0000, 20.0000, 20.0000]]])
==============>>>>>>>>>>>>>>>ds.shape= torch.Size([2, 7, 12])
tensor([[[ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.]],

        [[10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
         [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
         [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
         [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
         [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
         [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
         [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.]]])
==============>>>>>>>>>>>>>>>frustum.shape= torch.Size([2, 7, 12, 3])
tensor([[[[ 0.0000,  0.0000,  5.0000],
          [ 3.1818,  0.0000,  5.0000],
          [ 6.3636,  0.0000,  5.0000],
          [ 9.5455,  0.0000,  5.0000],
          [12.7273,  0.0000,  5.0000],
          [15.9091,  0.0000,  5.0000],
          [19.0909,  0.0000,  5.0000],
          [22.2727,  0.0000,  5.0000],
          [25.4545,  0.0000,  5.0000],
          [28.6364,  0.0000,  5.0000],
          [31.8182,  0.0000,  5.0000],
          [35.0000,  0.0000,  5.0000]],

         [[ 0.0000,  3.3333,  5.0000],
          [ 3.1818,  3.3333,  5.0000],
          [ 6.3636,  3.3333,  5.0000],
          [ 9.5455,  3.3333,  5.0000],
          [12.7273,  3.3333,  5.0000],
          [15.9091,  3.3333,  5.0000],
          [19.0909,  3.3333,  5.0000],
          [22.2727,  3.3333,  5.0000],
          [25.4545,  3.3333,  5.0000],
          [28.6364,  3.3333,  5.0000],
          [31.8182,  3.3333,  5.0000],
          [35.0000,  3.3333,  5.0000]],

         [[ 0.0000,  6.6667,  5.0000],
          [ 3.1818,  6.6667,  5.0000],
          [ 6.3636,  6.6667,  5.0000],
          [ 9.5455,  6.6667,  5.0000],
          [12.7273,  6.6667,  5.0000],
          [15.9091,  6.6667,  5.0000],
          [19.0909,  6.6667,  5.0000],
          [22.2727,  6.6667,  5.0000],
          [25.4545,  6.6667,  5.0000],
          [28.6364,  6.6667,  5.0000],
          [31.8182,  6.6667,  5.0000],
          [35.0000,  6.6667,  5.0000]],

         [[ 0.0000, 10.0000,  5.0000],
          [ 3.1818, 10.0000,  5.0000],
          [ 6.3636, 10.0000,  5.0000],
          [ 9.5455, 10.0000,  5.0000],
          [12.7273, 10.0000,  5.0000],
          [15.9091, 10.0000,  5.0000],
          [19.0909, 10.0000,  5.0000],
          [22.2727, 10.0000,  5.0000],
          [25.4545, 10.0000,  5.0000],
          [28.6364, 10.0000,  5.0000],
          [31.8182, 10.0000,  5.0000],
          [35.0000, 10.0000,  5.0000]],

         [[ 0.0000, 13.3333,  5.0000],
          [ 3.1818, 13.3333,  5.0000],
          [ 6.3636, 13.3333,  5.0000],
          [ 9.5455, 13.3333,  5.0000],
          [12.7273, 13.3333,  5.0000],
          [15.9091, 13.3333,  5.0000],
          [19.0909, 13.3333,  5.0000],
          [22.2727, 13.3333,  5.0000],
          [25.4545, 13.3333,  5.0000],
          [28.6364, 13.3333,  5.0000],
          [31.8182, 13.3333,  5.0000],
          [35.0000, 13.3333,  5.0000]],

         [[ 0.0000, 16.6667,  5.0000],
          [ 3.1818, 16.6667,  5.0000],
          [ 6.3636, 16.6667,  5.0000],
          [ 9.5455, 16.6667,  5.0000],
          [12.7273, 16.6667,  5.0000],
          [15.9091, 16.6667,  5.0000],
          [19.0909, 16.6667,  5.0000],
          [22.2727, 16.6667,  5.0000],
          [25.4545, 16.6667,  5.0000],
          [28.6364, 16.6667,  5.0000],
          [31.8182, 16.6667,  5.0000],
          [35.0000, 16.6667,  5.0000]],

         [[ 0.0000, 20.0000,  5.0000],
          [ 3.1818, 20.0000,  5.0000],
          [ 6.3636, 20.0000,  5.0000],
          [ 9.5455, 20.0000,  5.0000],
          [12.7273, 20.0000,  5.0000],
          [15.9091, 20.0000,  5.0000],
          [19.0909, 20.0000,  5.0000],
          [22.2727, 20.0000,  5.0000],
          [25.4545, 20.0000,  5.0000],
          [28.6364, 20.0000,  5.0000],
          [31.8182, 20.0000,  5.0000],
          [35.0000, 20.0000,  5.0000]]],


        [[[ 0.0000,  0.0000, 10.0000],
          [ 3.1818,  0.0000, 10.0000],
          [ 6.3636,  0.0000, 10.0000],
          [ 9.5455,  0.0000, 10.0000],
          [12.7273,  0.0000, 10.0000],
          [15.9091,  0.0000, 10.0000],
          [19.0909,  0.0000, 10.0000],
          [22.2727,  0.0000, 10.0000],
          [25.4545,  0.0000, 10.0000],
          [28.6364,  0.0000, 10.0000],
          [31.8182,  0.0000, 10.0000],
          [35.0000,  0.0000, 10.0000]],

         [[ 0.0000,  3.3333, 10.0000],
          [ 3.1818,  3.3333, 10.0000],
          [ 6.3636,  3.3333, 10.0000],
          [ 9.5455,  3.3333, 10.0000],
          [12.7273,  3.3333, 10.0000],
          [15.9091,  3.3333, 10.0000],
          [19.0909,  3.3333, 10.0000],
          [22.2727,  3.3333, 10.0000],
          [25.4545,  3.3333, 10.0000],
          [28.6364,  3.3333, 10.0000],
          [31.8182,  3.3333, 10.0000],
          [35.0000,  3.3333, 10.0000]],

         [[ 0.0000,  6.6667, 10.0000],
          [ 3.1818,  6.6667, 10.0000],
          [ 6.3636,  6.6667, 10.0000],
          [ 9.5455,  6.6667, 10.0000],
          [12.7273,  6.6667, 10.0000],
          [15.9091,  6.6667, 10.0000],
          [19.0909,  6.6667, 10.0000],
          [22.2727,  6.6667, 10.0000],
          [25.4545,  6.6667, 10.0000],
          [28.6364,  6.6667, 10.0000],
          [31.8182,  6.6667, 10.0000],
          [35.0000,  6.6667, 10.0000]],

         [[ 0.0000, 10.0000, 10.0000],
          [ 3.1818, 10.0000, 10.0000],
          [ 6.3636, 10.0000, 10.0000],
          [ 9.5455, 10.0000, 10.0000],
          [12.7273, 10.0000, 10.0000],
          [15.9091, 10.0000, 10.0000],
          [19.0909, 10.0000, 10.0000],
          [22.2727, 10.0000, 10.0000],
          [25.4545, 10.0000, 10.0000],
          [28.6364, 10.0000, 10.0000],
          [31.8182, 10.0000, 10.0000],
          [35.0000, 10.0000, 10.0000]],

         [[ 0.0000, 13.3333, 10.0000],
          [ 3.1818, 13.3333, 10.0000],
          [ 6.3636, 13.3333, 10.0000],
          [ 9.5455, 13.3333, 10.0000],
          [12.7273, 13.3333, 10.0000],
          [15.9091, 13.3333, 10.0000],
          [19.0909, 13.3333, 10.0000],
          [22.2727, 13.3333, 10.0000],
          [25.4545, 13.3333, 10.0000],
          [28.6364, 13.3333, 10.0000],
          [31.8182, 13.3333, 10.0000],
          [35.0000, 13.3333, 10.0000]],

         [[ 0.0000, 16.6667, 10.0000],
          [ 3.1818, 16.6667, 10.0000],
          [ 6.3636, 16.6667, 10.0000],
          [ 9.5455, 16.6667, 10.0000],
          [12.7273, 16.6667, 10.0000],
          [15.9091, 16.6667, 10.0000],
          [19.0909, 16.6667, 10.0000],
          [22.2727, 16.6667, 10.0000],
          [25.4545, 16.6667, 10.0000],
          [28.6364, 16.6667, 10.0000],
          [31.8182, 16.6667, 10.0000],
          [35.0000, 16.6667, 10.0000]],

         [[ 0.0000, 20.0000, 10.0000],
          [ 3.1818, 20.0000, 10.0000],
          [ 6.3636, 20.0000, 10.0000],
          [ 9.5455, 20.0000, 10.0000],
          [12.7273, 20.0000, 10.0000],
          [15.9091, 20.0000, 10.0000],
          [19.0909, 20.0000, 10.0000],
          [22.2727, 20.0000, 10.0000],
          [25.4545, 20.0000, 10.0000],
          [28.6364, 20.0000, 10.0000],
          [31.8182, 20.0000, 10.0000],
          [35.0000, 20.0000, 10.0000]]]])

Process finished with exit code 0

2. Testing `expand`


import torch

a = torch.Tensor([[1, 2, 3],
                  [4, 5, 6]])
print(a.shape)

# Two experiments: insert a size-1 axis (last / middle), then broadcast that
# axis to length 5 with expand, which repeats values without copying data.
_cases = (
    ((2, 3, 1), (2, 3, 5)),
    ((2, 1, 3), (2, 5, 3)),
)
for _i, (_view_shape, _expand_shape) in enumerate(_cases):
    if _i > 0:
        print("=====================================")
    e0 = a.view(*_view_shape)
    e00 = e0.expand(*_expand_shape)
    print(e0)
    print()
    print(e00)
torch.Size([2, 3])
tensor([[[1.],
         [2.],
         [3.]],

        [[4.],
         [5.],
         [6.]]])

tensor([[[1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]],

        [[4., 4., 4., 4., 4.],
         [5., 5., 5., 5., 5.],
         [6., 6., 6., 6., 6.]]])
=====================================
tensor([[[1., 2., 3.]],

        [[4., 5., 6.]]])

tensor([[[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[4., 5., 6.],
         [4., 5., 6.],
         [4., 5., 6.],
         [4., 5., 6.],
         [4., 5., 6.]]])

Process finished with exit code 0

2.1 `expand` on a 1-D tensor (the same `view(-1, 1, 1).expand(...)` pattern used for `ds`)

import torch

t1 = torch.Tensor([8, 9])

# Append trailing size-1 axes, then broadcast them (no data copy).
t2 = t1.unsqueeze(-1)                 # [2] -> [2, 1]
t22 = t2.expand(-1, 3)                # [2, 1] -> [2, 3]

t3 = t1.unsqueeze(-1).unsqueeze(-1)   # [2] -> [2, 1, 1]
t33 = t3.expand(-1, 3, 4)             # [2, 1, 1] -> [2, 3, 4]

for _label, _value in (
    ("t1.shape=", t1),
    ("t2.shape=", t2),
    ("t22.shape=", t22),
    ("t3.shape=", t3),
    ("t33.shape=", t33),
):
    print(_label, _value.shape)
    print(_value)
/media/algo/data_1/software/anconda_install/envs/pytorch1.7.0_general/bin/python3 /media/algo/data_1/project_others/0000paper/lss/project/lift-splat-shoot-master/0000/expand_test.py
t1.shape= torch.Size([2])
tensor([8., 9.])
t2.shape= torch.Size([2, 1])
tensor([[8.],
        [9.]])
t22.shape= torch.Size([2, 3])
tensor([[8., 8., 8.],
        [9., 9., 9.]])
t3.shape= torch.Size([2, 1, 1])
tensor([[[8.]],

        [[9.]]])
t33.shape= torch.Size([2, 3, 4])
tensor([[[8., 8., 8., 8.],
         [8., 8., 8., 8.],
         [8., 8., 8., 8.]],

        [[9., 9., 9., 9.],
         [9., 9., 9., 9.],
         [9., 9., 9., 9.]]])

Process finished with exit code 0

3. points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)

        # Unpack the two leading dims of the 3-D `trans` tensor; presumably
        # batch size and number of cameras per sample -- confirm against caller.
        B, N, _ = trans.shape
        # self.frustum: [D, H, W, 3], e.g. [41, 8, 22, 3]
        # undo post-transformation
        # Broadcasting: [D, H, W, 3] - [B, N, 1, 1, 1, 3] -> B x N x D x H x W x 3,
        # e.g. [41, 8, 22, 3] - [2, 5, 1, 1, 1, 3] -> [2, 5, 41, 8, 22, 3].
        # NOTE(review): the view assumes post_trans holds B*N*3 values, i.e. (B, N, 3).
        points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
        # Multiply every point by the inverse of post_rots: the inverse is viewed as
        # [B, N, 1, 1, 1, 3, 3] and matmul'd with the points as column vectors
        # [B, N, D, H, W, 3, 1], giving [B, N, D, H, W, 3, 1].
        # NOTE(review): assumes post_rots is (B, N, 3, 3) and invertible.
        points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))

Test demo: broadcasting subtraction (same pattern as `self.frustum - post_trans.view(B, N, 1, 1, 1, 3)`)

import torch

# a: [2, 3] and b: [3, 1, 1, 3], random integers in [1, 100).
a = torch.randint(1, 100, (2, 3))
b = torch.randint(1, 100, (3, 1, 1, 3))

# Broadcasting aligns trailing dims: [2, 3] vs [3, 1, 1, 3] -> [3, 1, 2, 3].
c = torch.sub(a, b)

for _item in (a, b, c, c.shape):
    print(_item)
tensor([[37,  8, 68],
        [58,  9,  5]])
tensor([[[[76,  1, 58]]],


        [[[61, 81, 75]]],


        [[[18,  9, 98]]]])
tensor([[[[-39,   7,  10],
          [-18,   8, -53]]],


        [[[-24, -73,  -7],
          [ -3, -72, -70]]],


        [[[ 19,  -1, -30],
          [ 40,   0, -93]]]])
torch.Size([3, 1, 2, 3])

Process finished with exit code 0
posted @ 2023-06-02 17:57  无左无右  阅读(118)  评论(0)  编辑  收藏  举报