影醉阏轩窗


Work Summary 3

torch.max() with two input tensors

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

While reading some source code recently, I came across a trick I did not understand:

import torch


def find_intersection(set_1, set_2):
    """
    Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.

    :param set_1: set 1, a tensor of dimensions (n1, 4)
    :param set_2: set 2, a tensor of dimensions (n2, 4)
    :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
    """

    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)

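The idea there is that the intersection area is simply the overlap along the x-axis multiplied by the overlap along the y-axis, where each overlap follows from the boundary positions and extents of the two boxes along that axis.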
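As a quick sanity check of that idea, here is a minimal worked example for a single pair of boxes in plain Python. The two boxes and the helper name `overlap_1d` are made up for illustration; the boxes are assumed to be in boundary coordinates (x_min, y_min, x_max, y_max), as in the docstring above.

```python
def overlap_1d(lo_a, hi_a, lo_b, hi_b):
    """Length of the overlap between intervals [lo_a, hi_a] and [lo_b, hi_b]."""
    # Largest lower bound and smallest upper bound, clamped at zero,
    # mirroring torch.max / torch.min / torch.clamp in find_intersection.
    return max(0.0, min(hi_a, hi_b) - max(lo_a, lo_b))


# Hypothetical boxes in (x_min, y_min, x_max, y_max) form.
box_a = (0.0, 0.0, 2.0, 2.0)
box_b = (1.0, 1.0, 3.0, 4.0)

overlap_x = overlap_1d(box_a[0], box_a[2], box_b[0], box_b[2])
overlap_y = overlap_1d(box_a[1], box_a[3], box_b[1], box_b[3])
area = overlap_x * overlap_y
print(overlap_x, overlap_y, area)
```

find_intersection does the same arithmetic, only for every pair of boxes at once, which is where the two-tensor form of torch.max() comes in.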

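It turns out my confusion came from torch.max(). I will not go over the simple usages here, only the last form, which the official PyTorch documentation also covers only briefly: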

torch.max(input, other, out=None) → Tensor
Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken.

The shapes of input and other don’t need to match, but they must be broadcastable.

\text{out}_i = \max(\text{tensor}_i, \text{other}_i)
NOTE

When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.

Parameters
input (Tensor) – the input tensor.

other (Tensor) – the second input tensor

out (Tensor, optional) – the output tensor.

Example:

>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416,  0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416,  0.2653, -0.1584])

Normally, when the two input tensors have the same shape, they are simply compared element by element.

What happens when the two tensors have different shapes is not explained in any detail in the official text above:

Here is an example, with aaa of shape 2 * 2 and bbb of shape 3 * 2:

aaa = torch.randn(2,2)
bbb = torch.randn(3,2)
ccc = torch.max(aaa,bbb)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

This raises the error above. Let's analyze it first:

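A 2 * 2 tensor and a 3 * 2 tensor cannot be compared directly. If every row of one were compared element-wise with every row of the other, as the official description seems to suggest, the output should have shape 2 * 3 * 2. Let's test further: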

aaa = torch.randn(1,2)
bbb = torch.randn(3,2)
ccc = torch.max(aaa,bbb)
tensor([[1.0350, 0.2532],
        [0.2203, 0.2532],
        [0.2912, 0.2532]])

This produces a result directly, with no error.

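So PyTorch's rule is this: wherever the sizes along a dimension differ, one of them must be 1 (a singleton dimension), and that one is broadcast to match the other.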
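The rule can be sketched in plain Python as follows. This is only an illustration of the shape check, not how PyTorch implements it; the function name `broadcast_shape` and the error wording are my own assumptions.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    # Align the shapes at the trailing dimension; missing leading
    # dimensions are treated as size 1.
    ndim = max(len(shape_a), len(shape_b))
    padded_a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)

    result = []
    for dim, (a, b) in enumerate(zip(padded_a, padded_b)):
        if a == b or b == 1:
            result.append(a)   # equal sizes, or b is stretched to a
        elif a == 1:
            result.append(b)   # a is stretched to b
        else:
            raise ValueError(f"size {a} vs {b} at non-singleton dimension {dim}")
    return tuple(result)


print(broadcast_shape((1, 2), (3, 2)))
print(broadcast_shape((2, 1, 2), (1, 3, 2)))
try:
    broadcast_shape((2, 2), (3, 2))
except ValueError as err:
    print("incompatible:", err)
```

With this check, 2 * 2 against 3 * 2 is rejected, while 1 * 2 against 3 * 2 is accepted because the size-1 dimension can be stretched.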

So we can test one step further and convert the 2 * 2 and 3 * 2 inputs into 2 * 1 * 2 and 1 * 3 * 2:

aaa = torch.randn(2,2).unsqueeze(1)  # (2, 1, 2)
bbb = torch.randn(3,2).unsqueeze(0)  # (1, 3, 2)
ccc = torch.max(aaa,bbb)
ccc.shape
torch.Size([2, 3, 2])

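And with that, the problem is solved! When I have time I will look at how the source code implements this, and why it is not smarter about it...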
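To double-check what the broadcast call computes, the following sketch compares it against an explicit double loop over the rows of the two tensors. The seed and the variable name `expected` are arbitrary choices for this illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability
aaa = torch.randn(2, 2)
bbb = torch.randn(3, 2)

# Broadcast version: (2, 1, 2) against (1, 3, 2) gives (2, 3, 2).
ccc = torch.max(aaa.unsqueeze(1), bbb.unsqueeze(0))

# Explicit version: compare every row of aaa with every row of bbb.
expected = torch.empty(2, 3, 2)
for i in range(aaa.size(0)):
    for j in range(bbb.size(0)):
        expected[i, j] = torch.max(aaa[i], bbb[j])

print(ccc.shape)
print(torch.equal(ccc, expected))
```

Entry ccc[i, j] holds the element-wise maximum of row i of aaa and row j of bbb, which is exactly the all-pairs comparison that find_intersection relies on.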

Posted on 2019-11-17 20:53 by 影醉阏轩窗
