CNN 网络层定义的输入
在CNN网络层定义中,发现了如下的问题:
其中红框的位置,应该是输入数据的维度,比如,我们这里的输入数据格式为: torch.Size([8, 4, 84, 84]),其中,8是batch-size, 4 为维度数,按说,红框位置处应为维度4,但这里却是直接使用input 数据,而不是input的格式或维度。
经研究,我们认为这是属于pytorch内在机制决定的,应该是无论是输入维度4还是整个数据的格式都是可以的,如果输入的是整个数据的格式,则会自动默认格式为:
并从中选择Cin作为定义网络层的维度。