torch.meshgrid函数

在此记录下torch.meshgrid的用法,该函数常常用于生成二维的网格:

1
2
3
4
5
6
7
8
9
10
11
>>> x = torch.tensor([123])
>>> y = torch.tensor([456])
>>> grid_x, grid_y = torch.meshgrid(x, y)
>>> grid_x
tensor([[111],
        [222],
        [333]])
>>> grid_y
tensor([[456],
        [456],
        [456]])

  另一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
>>> import torch
>>> h = 6
>>> w = 10
>>> ys,xs = torch.meshgrid(torch.arange(h), torch.arange(w))
>>> xs.shape
torch.Size([610])
>>> ys.shape
torch.Size([610])
>>> xs
tensor([[0123456789],
        [0123456789],
        [0123456789],
        [0123456789],
        [0123456789],
        [0123456789]])
>>> ys
tensor([[0000000000],
        [1111111111],
        [2222222222],
        [3333333333],
        [4444444444],
        [5555555555]])
>>> xys = torch.stack([xs, ys], dim=-1)
>>> xys.shape
torch.Size([6102])

  需要注意的点:

  1. torch.meshgrid函数的输入是若干个(N个)一维Tensor或者若干个标量。

  2. torch.meshgrid函数的输出有N个,每个输出都是N维的。

  3. torch.meshgrid函数的每个输出tensor的shape都为(d1,d2,d3...dN)(d1,d2,d3...dN),其中didi为第i个输入向量的长度。

  4. torch.meshgrid函数的每个输出有什么不同?答:为该输出对应输入向量在其他维度舒展开的结果。

  5. torch的meshgrid实现和numpy的meshgrid实现有所不同,后者“可能”能够更直接地获取我们需要的东西,而torch的meshgrid调用后可能还需要做一个转置。

posted @ 2022-09-11 10:49  海_纳百川  阅读(512)  评论(0编辑  收藏  举报
本站总访问量