探究grid_sample函数
一、函数介绍
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
-
对于4D输入,
input
维度为 \((N,C,H_{in},W_{in})\),grid
维度为 \((N,H_{out},W_{out},2)\) ,则output
维度为 \((N,C,H_{out},W_{out})\) -
对于5D输入,
input
维度为 \((N,C,D_{in},H_{in},W_{in})\),grid
维度为 \((N,D_{out},H_{out},W_{out},3)\) ,则output
维度为 \((N,C,D_{out},H_{out},W_{out})\) -
gird
储存着用于在输入特征图上进行元素采样的坐标偏移量。grid的元素值通常在 \(\left [-1, 1 \right ]\) 之间, \(\left (-1, -1 \right )\) 表示取输入特征图左上角的元素, \(\left (1, 1 \right )\) 表示取输入特征图右下角的元素。
二、示例代码
import torch
import torch.nn.functional as F
# 定义一个 4x4 的输入张量
input_tensor = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
], dtype=torch.float).view(1, 1, 4, 4)
print(input_tensor)
# 定义采样点,归一化坐标在 [-1, 1] 范围内
grid = torch.tensor([[
[[-0.5, -0.5],
[0.5, -0.5]],
[[-0.5, 0.5],
[0.5, 0.5]],
]], dtype=torch.float)
print(grid)
# 使用 F.grid_sample 进行采样
output = F.grid_sample(input_tensor, grid, align_corners=True)
print(output)
计算过程
假设输入张量的尺寸为 (4, 4)
,采样点坐标的归一化范围在 [-1, 1]
,我们将其转换为张量坐标的范围 [0, 3]
。
归一化坐标转换公式
归一化坐标转换公式如下:
示例计算 1:归一化采样点 [-0.5, -0.5]
对于归一化采样点 [-0.5, -0.5]
,我们将其转换为输入张量的实际坐标:
这样,归一化坐标 [-0.5, -0.5]
对应的输入张量实际坐标为 [0.75, 0.75]
。
假设采样点 (x, y)
对应输入张量的坐标 [0.75, 0.75]
,我们可以确定其周围的四个像素值:
左上角像素 (0, 0)
右上角像素 (0, 1)
左下角像素 (1, 0)
右下角像素 (1, 1)
使用双线性插值公式计算插值值:
top_left = input_tensor[0, 0, 0, 0] # 1
top_right = input_tensor[0, 0, 0, 1] # 2
bottom_left = input_tensor[0, 0, 1, 0] # 5
bottom_right = input_tensor[0, 0, 1, 1] # 6
value = (1-0.75)*(1-0.75)*f(0,0) + (1-0.25)*(1-0.75)*f(0,1) \
+ (1-0.75)*(1-0.25)*f(1,0) + (1-0.25)*(1-0.25)*f(1,1)
value = (1 - 0.75) * (1 - 0.75) * 1 + 0.75 * (1 - 0.75) * 2 \
+ (1 - 0.75) * 0.75 * 5 + 0.75 * 0.75 * 6 = 4.75
补充知识:
1.性插值法(linear interpolation)
假设我们已知坐标 (x0, y0)
与 (x1, y1)
,要得到 [x0, x1]
区间内某一位置 x
在直线上的值。根据图中所示,我们得到
由于 x
值已知,所以可以从公式得到 y
的值
2.双线性插值法(bilinear interpolation)
在数学上,双线性插值是有两个变量的插值函数的线性插值扩展,其核心思想是在两个方向分别进行一次线性插值。
如坐标图所示,用横纵坐标代表图像像素的位置,f(x,y)
代表该像素点 (x,y)
的彩色值或灰度值。
假设我们已知函数f(x,y)
在 Q11=(x1, y1)
、Q12=(x1, y2)
, Q21=(x2, y1)
以及 Q22=(x2, y2)
四个点的值
若想得到未知函数f(x,y)
在点P=(x, y)
的值,首先在 x 方向进行线性插值,得到
然后在 y 方向进行线性插值,得到
这样就得到所要的结果 f(x, y)
,
2.1 单位正方形
如果选择一个坐标系统使得 f(x,y)
的四个已知点坐标分别为 (0, 0)
、(0, 1)
、(1, 0)
和 (1, 1)
,那么插值公式就可以化简为
或者用矩阵运算表示为
2.2 非线性
双线性插值的结果不是线性的,它是两个线性函数的积。在单位正方形上,双线性插值可以记作
常数的数目(4个)对应于给定的 f(x,y)
的数据点数目
双线性插值的结果与插值的顺序无关。首先进行 y
方向的插值,然后进行 x
方向的插值,所得到的结果是一样的。双线性插值的一个显然的三维空间延伸是三线性插值。
参考文章: