A convolution-kernel problem when converting PyTorch to ONNX
Exporting from PyTorch to ONNX fails with the following error:
RuntimeError: Unsupported: ONNX export of convolution for kernel of unknown shape.
The offending part of my code:
def forward(self, input):
    n, c, h, w = input.size()
    s = self.scale_factor
    # pad input (left, right, top, bottom)
    input = F.pad(input, (0, 1, 0, 1), mode='replicate')
    # calculate output (height)
    kernel_h = self.kernels.repeat(c, 1).view(-1, 1, s, 1)
    output = F.conv2d(input, kernel_h, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, -1, w + 1).permute(0, 1, 3, 2, 4).reshape(n, c, -1, w + 1)
    # calculate output (width)
    kernel_w = self.kernels.repeat(c, 1).view(-1, 1, 1, s)
    output = F.conv2d(output, kernel_w, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, h * s, -1).permute(0, 1, 3, 4, 2).reshape(n, c, h * s, -1)
    return output
The cause: when using the convolution function (torch.nn.functional.conv2d, as opposed to the convolution layer torch.nn.Conv2d), the kernel size must be fixed; it cannot depend on the input. In the code above, both kernel_h and kernel_w are built from self.kernels, and the repeat() call depends on c, the channel dimension of the input, which is what triggers the error.
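To isolate the pattern, here is a minimal sketch that reproduces the same failure (the module name, kernel layout, and file name are hypothetical, not the code above): the conv weight is rebuilt from c, which is a traced value rather than a plain int during ONNX export, so the exporter cannot determine the kernel shape.

import torch
import torch.nn.functional as F

class DynamicKernelUpsample(torch.nn.Module):
    # Hypothetical minimal module: the conv weight is rebuilt from the
    # channel count c taken from the input, so its shape is input-dependent.
    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.kernels = torch.nn.Parameter(torch.ones(scale_factor, scale_factor))

    def forward(self, x):
        n, c, h, w = x.size()          # during ONNX-export tracing, c is a traced value, not an int
        s = self.scale_factor
        k = self.kernels.repeat(c, 1).view(-1, 1, s, 1)   # weight shape depends on c
        x = F.pad(x, (0, 0, 0, 1), mode='replicate')
        return F.conv2d(x, k, stride=1, padding=0, groups=c)

# Runs fine in eager mode, but the export below raises the RuntimeError above:
# torch.onnx.export(DynamicKernelUpsample(), torch.randn(1, 3, 8, 8), 'repro.onnx')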
The simplest fix is to add c = int(c) at the start, so that ONNX sees c as a plain Python int constant rather than a torch.Tensor; the sizes of kernel_h and kernel_w then become fixed. The final code:
def forward(self, input):
    n, c, h, w = input.size()
    s = self.scale_factor
    c = int(c)
    # pad input (left, right, top, bottom)
    input = F.pad(input, (0, 1, 0, 1), mode='replicate')
    # calculate output (height)
    kernel_h = self.kernels.repeat(c, 1).view(-1, 1, s, 1)
    output = F.conv2d(input, kernel_h, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, -1, w + 1).permute(0, 1, 3, 2, 4).reshape(n, c, -1, w + 1)
    # calculate output (width)
    kernel_w = self.kernels.repeat(c, 1).view(-1, 1, 1, s)
    output = F.conv2d(output, kernel_w, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, h * s, -1).permute(0, 1, 3, 4, 2).reshape(n, c, h * s, -1)
    return output
This, however, produces a warning during the ONNX export:
TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
c = int(c)
In my application, c is always a fixed value, so there is no concern about the trace failing to generalize when c changes, and the warning can safely be ignored.
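Since that trace-constant behavior is exactly what we rely on here, a quick sanity check against onnxruntime can confirm that the exported graph behaves correctly for inputs with the same channel count as the dummy input. Below is a sketch under that assumption; UpsampleModule and the file name are placeholders, not part of the original code.

import numpy as np
import onnxruntime as ort
import torch

model = UpsampleModule(scale_factor=2)   # placeholder for the module containing forward() above
model.eval()
dummy = torch.randn(1, 3, 64, 64)        # c = 3 gets frozen into the exported graph
torch.onnx.export(model, dummy, 'upsample.onnx',
                  input_names=['input'], output_names=['output'])

# Compare ONNX Runtime output with the PyTorch output for a new input
# that has the same channel count as the dummy used for export.
x = torch.randn(1, 3, 64, 64)
sess = ort.InferenceSession('upsample.onnx')
onnx_out = sess.run(None, {'input': x.numpy()})[0]
with torch.no_grad():
    torch_out = model(x).numpy()
np.testing.assert_allclose(onnx_out, torch_out, rtol=1e-3, atol=1e-5)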