Debugging a slow backward pass in PyTorch caused by computation-graph expansion
Here is the situation: I wrote a block myself, with the following contents.
import torch
import torch.nn as nn

# For easier integration: given two angles, generate a compact tilted image grid
class Compact_Homo(nn.Module):
    def __init__(self, device):
        super(Compact_Homo, self).__init__()
        # assume the intrinsic matrix K is the identity
        self.d = 5  # default distance from the object to the optical center
        self.device = device

    def forward(self, alpha, beta, size, d):
        # alpha: (N, 1), beta: (N, 1), size: (N, C, H, W), d: (N, 1)
        B = alpha.shape[0]
        if d is None:
            d = torch.full((B, 1), float(self.d), device=self.device)  # fall back to the default distance
        # image size
        if size is None:
            size = (B, 3, 1024, 1024)
        N, C, H, W = size
        N = B
        # rotation about the x-axis by beta
        Rotx = torch.zeros(B, 3, 3).to(self.device).clone()
        ones = torch.ones(B,).to(self.device).clone()
        Rotx[:, 0, 0] = ones
        Rotx[:, 1, 1] = torch.cos(beta).squeeze(1)
        Rotx[:, 1, 2] = -torch.sin(beta).squeeze(1)
        Rotx[:, 2, 1] = torch.sin(beta).squeeze(1)
        Rotx[:, 2, 2] = torch.cos(beta).squeeze(1)
        # rotation about the y-axis by alpha
        Roty = torch.zeros(B, 3, 3).to(self.device).clone()
        ones = torch.ones(B,).to(self.device).clone()
        Roty[:, 1, 1] = ones.clone()
        Roty[:, 0, 0] = torch.cos(alpha).squeeze(1)
        Roty[:, 0, 2] = torch.sin(alpha).squeeze(1)
        Roty[:, 2, 0] = -torch.sin(alpha).squeeze(1)
        Roty[:, 2, 2] = torch.cos(alpha).squeeze(1)
        # build the homography
        R = torch.bmm(Rotx, Roty)
        R_1 = torch.inverse(R).clone()  # the required shape differs across PyTorch versions
        t = torch.zeros(B, 3).to(self.device)
        t[:, 2] = d.squeeze(1).clone()  # translation vector
        R_1[:, :, 2] = t.clone()  # overwrite the third column
        temp_homo = R_1.clone()
        homo = torch.inverse(R_1).clone()
        # -------------------
        # map the unit circle through the homography and solve for its center and scale
        C = torch.zeros(B, 3, 3).to(self.device).clone()  # conic; shadows the channel count from `size`
        C[:, 0, 0] = 1.
        C[:, 1, 1] = 1.
        C[:, 2, 2] = -1.
        C2 = torch.bmm(torch.inverse(torch.transpose(temp_homo, 1, 2)), C)
        C2_ = torch.bmm(C2, torch.inverse(temp_homo))
        C3 = torch.inverse(C2_)  # dual form of the transformed conic
        # horizontal extent: roots of the quadratic cut by the dual conic
        a = C3[:, 0, 0]
        b = C3[:, 0, 2] + C3[:, 2, 0]
        c = C3[:, 2, 2]
        right_x = (-b - torch.sqrt(b.mul(b) - 4 * a.mul(c))) / (2 * a)
        left_x = (-b + torch.sqrt(b.mul(b) - 4 * a.mul(c))) / (2 * a)
        right_x = -1. / right_x
        left_x = -1. / left_x
        width = right_x - left_x
        center_x = (right_x + left_x) / 2
        # vertical extent, same construction
        a_ = C3[:, 1, 1]
        b_ = C3[:, 1, 2] + C3[:, 2, 1]
        c_ = C3[:, 2, 2]
        bottom_y = (-b_ - torch.sqrt(b_.mul(b_) - 4 * a_.mul(c_))) / (2 * a_)
        top_y = (-b_ + torch.sqrt(b_.mul(b_) - 4 * a_.mul(c_))) / (2 * a_)
        bottom_y = -1. / bottom_y
        top_y = -1. / top_y
        height = bottom_y - top_y
        center_y = (top_y + bottom_y) / 2
        scale = torch.max(width, height)
        # ---------------------
        # build the compact sampling grid from the homography, center and scale
        N, C, H, W = size  # re-unpack: C was overwritten by the conic above
        N = B
        base_grid = torch.zeros(N, H, W, 2).to(self.device)
        linear_points = torch.linspace(-1, 1, W).to(self.device) if W > 1 else torch.Tensor([-1]).to(self.device)
        # torch.ger is the outer product (spelled torch.outer in newer PyTorch)
        base_grid[:, :, :, 0] = torch.ger(torch.ones(H).to(self.device), linear_points).expand_as(base_grid[:, :, :, 0])
        linear_points = torch.linspace(-1, 1, H).to(self.device) if H > 1 else torch.Tensor([-1]).to(self.device)
        base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W).to(self.device)).expand_as(base_grid[:, :, :, 1])
        base_grid = base_grid.view(N, H * W, 2)
        # shift and scale the grid by the recovered center and scale
        center_x = center_x.unsqueeze(1)
        center_y = center_y.unsqueeze(1)
        center = torch.cat((center_x, center_y), 1).unsqueeze(1).repeat(1, W * H, 1)
        scale = scale.unsqueeze(1).repeat(1, H * W).unsqueeze(2).repeat(1, 1, 2)
        base_grid = base_grid * scale / 2
        base_grid = base_grid + center
        # expand the homography so it can be applied to every grid point
        h = homo.unsqueeze(1).repeat(1, W * H, 1, 1)
        temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
        temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
        u1 = temp1 / temp2
        temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
        temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
        v1 = temp3 / temp4
        grid1 = u1.view(N, H, W, 1)
        grid2 = v1.view(N, H, W, 1)
        grid = torch.cat((grid1, grid2), 3)
        return grid
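For context, the block gets consumed roughly like the driver below. This is a minimal sketch: the random image, the mean-based loss, the name homo_block and the timing scaffold are hypothetical stand-ins, not my actual training code.

import time
import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
homo_block = Compact_Homo(device)

B = 4
img = torch.rand(B, 3, 512, 512, device=device)                            # stand-in input image
alpha = ((torch.rand(B, 1, device=device) - 0.5) * 0.2).requires_grad_()   # small tilt angles
beta = ((torch.rand(B, 1, device=device) - 0.5) * 0.2).requires_grad_()
d = torch.full((B, 1), 5., device=device, requires_grad=True)

grid = homo_block(alpha, beta, img.shape, d)        # forward pass: fast
warped = F.grid_sample(img, grid, align_corners=True)
loss = warped.mean()                                # stand-in loss

if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.time()
loss.backward()                                     # this is the step that took 20-odd seconds
if device.type == 'cuda':
    torch.cuda.synchronize()
print(f'backward: {time.time() - start:.2f}s')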
But when I call this block from the main program, computing the loss and backpropagating takes 20-odd seconds, even though the forward pass is fast.
At first I suspected that functions like torch.inverse or torch.sqrt were dragging down the backward pass, but on reflection the derivative of an inverse, or of a square root, is not complicated.
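That hunch is easy to sanity-check in isolation. For Y = inverse(X) the backward is just two matrix products (dL/dX = -Yᵀ (dL/dY) Yᵀ), and the derivative of sqrt(x) is 1/(2·sqrt(x)). A throwaway micro-benchmark with made-up shapes, unrelated to the real model, bears this out:

import time
import torch

# a well-conditioned batch of 3x3 matrices (shifted by 3*I so inverse and sqrt are stable)
A = (torch.randn(64, 3, 3) + 3 * torch.eye(3)).requires_grad_()

start = time.time()
for _ in range(1000):
    loss = torch.inverse(A).sum() + torch.sqrt(A.abs() + 1e-6).sum()
    loss.backward()
# typically well under a second in total, i.e. nowhere near the 20 s I was seeing
print(f'1000 forward+backward passes: {time.time() - start:.3f}s')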
A thread I found on the PyTorch forum described the same symptom; the poster's problem turned out to be that the computation graph had been expanded enormously, a direction I had not considered at first. Debugging step by step, I found that after detach()-ing center and scale the runtime shrank drastically, from about 20 seconds down to 6, so my first thought was simply that it was fast because nothing needed to be backpropagated through them.
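Concretely, the probe was nothing more than cutting the graph at those two quantities inside forward (a debugging aid, not a fix, since it also blocks gradients from flowing back to alpha, beta and d along that branch):

# debugging probe inside forward(), just before center and scale are used:
center_x = center_x.detach()
center_y = center_y.detach()
scale = scale.detach()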
Continuing the debugging
I found that, taking this piece of the code above:
temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
u1 = temp1 / temp2
temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
and replacing h in it with elements taken directly from homo, the forward gradient flow is preserved. If the problem really lay with torch.inverse or torch.sqrt, this substitution should in theory have no effect on speed; yet when I made it, the backward time shrank enormously.
That suggested why detach()-ing center and scale had also cut the runtime: it must be related to repeat, since h is just homo repeated W*H times. So I went ahead and removed the repeat:
h = homo  # (B, 3, 3): drop the (W*H)-fold repeat and rely on broadcasting
# temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
# temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
# the slices below are kept two-dimensional, (B, 1), so they broadcast against the (B, H*W) grid
temp1 = (h[:, 0:1, 0] * base_grid[:, :, 0] + h[:, 0:1, 1] * base_grid[:, :, 1] + h[:, 0:1, 2])
temp2 = (h[:, 2:3, 0] * base_grid[:, :, 0] + h[:, 2:3, 1] * base_grid[:, :, 1] + h[:, 2:3, 2])
u1 = temp1 / temp2
# temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
# temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
temp3 = (h[:, 1:2, 0] * base_grid[:, :, 0] + h[:, 1:2, 1] * base_grid[:, :, 1] + h[:, 1:2, 2])
temp4 = (h[:, 2:3, 0] * base_grid[:, :, 0] + h[:, 2:3, 1] * base_grid[:, :, 1] + h[:, 2:3, 2])
v1 = temp3 / temp4
# transform center and scale; again broadcasting replaces the explicit repeats
center_x = center_x.unsqueeze(1)
center_y = center_y.unsqueeze(1)
# center = torch.cat((center_x, center_y), 1).unsqueeze(1).repeat(1, W*H, 1)
# scale = scale.unsqueeze(1).repeat(1, H*W).unsqueeze(2).repeat(1, 1, 2)
center = torch.cat((center_x, center_y), 1).unsqueeze(1)  # (B, 1, 2) broadcasts against (B, H*W, 2)
scale = scale.view(B, 1, 1)  # (B, 1, 1) likewise
base_grid = base_grid * scale / 2.
base_grid = base_grid + center
With that change, the backward time dropped at once from the 20-odd seconds shown in the first timing screenshot to almost nothing in the second (screenshots from the original post not reproduced here).
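The effect is easy to reproduce outside this module. Once homo requires grad, every slice taken from the repeated (B, W*H, 3, 3) tensor forces backward to materialize a zero-filled gradient buffer of that full size and scatter into it, while the broadcast version keeps the graph at (B, 1, 3, 3) and accumulates the gradient with one cheap reduction. A toy comparison with made-up shapes, not those of the real model:

import time
import torch

B, P = 4, 512 * 512                     # batch size and number of grid points (H*W)
homo = torch.randn(B, 3, 3, requires_grad=True)
grid = torch.randn(B, P)

def backward_time(use_repeat):
    # same arithmetic either way; only the expansion of homo differs
    if use_repeat:
        h = homo.unsqueeze(1).repeat(1, P, 1, 1)   # (B, P, 3, 3): P copies enter the graph
    else:
        h = homo.unsqueeze(1)                      # (B, 1, 3, 3): broadcasting, no copies
    num = h[:, :, 0, 0] * grid + h[:, :, 0, 2]
    den = h[:, :, 2, 0] * grid + h[:, :, 2, 2]
    loss = (num / den).sum()
    homo.grad = None
    start = time.time()
    loss.backward()
    return time.time() - start

print('with repeat:       %.3fs' % backward_time(True))
print('with broadcasting: %.3fs' % backward_time(False))

Both variants produce the same homo.grad; the repeated one just pays for per-slice gradient buffers of shape (B, P, 3, 3), which is where the backward time goes.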
PyTorch forum thread: https://discuss.pytorch.org/t/why-loss-backward-is-so-slow-taking-about-20s/122956/3
posted on 2021-06-02 11:25 by YongjieShi