nn.Linear输入输出维度的理解

目录

    最近写代码发现nn.Linear的输入维度其实可以不是一维的输入。

    一般来说,nn.Linear的输入一般是(batch_size, input_dimen)。但是官方文档说明了:
    input:(,input_dimen)
    output:(
    ,output_dimen)

    因此其实*是可以任意维度的。

    但是要注意的是,无论*的维度怎么样,真正的输入层和输出层神经元的个数和结构,是由input_dimen和output_dimen决定的,因此神经网络的参数也是其决定的。

    也就是说:无论的维度是多少,输入的数据都是相当于batch_size的作用,公用一套神经网络参数。

    官方文档的解释:
    https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

    image

    posted @ 2024-05-26 11:12  JaxonYe  阅读(36)  评论(0编辑  收藏  举报