On the network initialization call torch.nn.init.xavier_normal_(rand_conv.weight.data)

```python
import torch
import torch.nn as nn


def random_convolution(imgs):
    '''
    random convolution in "network randomization"

    (imgs): B x (C x stack) x H x W, note: imgs should be a normalized torch tensor
    '''
    _device = imgs.device

    img_h, img_w = imgs.shape[2], imgs.shape[3]
    num_stack_channel = imgs.shape[1]
    num_batch = imgs.shape[0]

    # one random transformation per sample, so each slice holds exactly one sample
    num_trans = num_batch
    batch_size = num_batch // num_trans

    # initialize random convolution (3 input -> 3 output channels, 3x3 kernel)
    rand_conv = nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1).to(_device)

    outs = []
    for trans_index in range(num_trans):
        # re-sample the kernel in place: every slice gets different random weights
        torch.nn.init.xavier_normal_(rand_conv.weight.data)
        temp_imgs = imgs[trans_index*batch_size:(trans_index+1)*batch_size]

        temp_imgs = temp_imgs.reshape(-1, 3, img_h, img_w)  # (batch x stack, channel, h, w)

        # all stacked frames of this slice share the same kernel
        outs.append(rand_conv(temp_imgs))

    # concatenate once, then restore the input layout B x (C x stack) x H x W
    total_out = torch.cat(outs, 0)
    return total_out.reshape(-1, num_stack_channel, img_h, img_w)
```
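The key line re-draws the convolution weights in place. The following minimal check is an illustration written for this note (the variable names are my own, not from the original code). Per the PyTorch documentation, xavier_normal_ samples from N(0, std²) with std = gain × sqrt(2 / (fan_in + fan_out)). For a Conv2d(3, 3, kernel_size=3) weight of shape (3, 3, 3, 3), fan_in = fan_out = 3 × 3 × 3 = 27, so with the default gain of 1 the expected std is sqrt(1/27) ≈ 0.192.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# same layer shape as in random_convolution: 3 -> 3 channels, 3x3 kernel, no bias
rand_conv = nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1)

# weight shape is (out_channels, in_channels, kH, kW) = (3, 3, 3, 3)
receptive_field = rand_conv.weight.shape[2] * rand_conv.weight.shape[3]  # 9
fan_in = rand_conv.weight.shape[1] * receptive_field   # 3 * 9 = 27
fan_out = rand_conv.weight.shape[0] * receptive_field  # 3 * 9 = 27
expected_std = math.sqrt(2.0 / (fan_in + fan_out))     # sqrt(1/27), about 0.192

w_before = rand_conv.weight.data.clone()
ret = torch.nn.init.xavier_normal_(rand_conv.weight.data)  # re-samples in place

# the call writes into the layer's existing weight storage (and also returns it)
same_storage = ret.data_ptr() == rand_conv.weight.data_ptr()
changed = not torch.equal(w_before, rand_conv.weight.data)

# the empirical std over many independent draws should be close to expected_std
draws = torch.stack(
    [torch.nn.init.xavier_normal_(torch.empty(3, 3, 3, 3)) for _ in range(2000)]
)
empirical_std = draws.std().item()

print(f"same storage: {same_storage}, weights changed: {changed}")
print(f"expected std {expected_std:.4f}, empirical std {empirical_std:.4f}")
```

Because the call mutates the tensor that rand_conv already holds, every rand_conv(...) call after it uses the new weights until the next re-initialization.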

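In the code above, torch.nn.init.xavier_normal_(rand_conv.weight.data) is the network's initialization step: it re-samples the weights of rand_conv in place, so it affects every subsequent call to rand_conv until the next re-initialization. Note that the function deliberately does not process the whole (128, 9, 84, 84) batch in one go. Because num_trans = num_batch, each loop iteration handles a single (1, 9, 84, 84) sample (reshaped into three 3-channel frames) and re-initializes the weights first.
The goal is to process each sample with its own randomly initialized network parameters, shared by the frames of that sample, rather than to apply one set of network parameters to all 128 image stacks.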
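If the per-sample Python loop becomes a bottleneck, the same idea (one freshly initialized kernel per sample, shared by that sample's frames) can in principle be expressed as a single grouped convolution. The sketch below is a hypothetical alternative, not the code from this post: random_convolution_grouped is a made-up name, and it calls torch.nn.functional.conv2d with groups equal to the batch size instead of reusing one nn.Conv2d. It also checks itself against a per-sample loop that uses the same kernels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def random_convolution_grouped(imgs):
    """Hypothetical vectorized variant of random_convolution.

    imgs: B x (3 x stack) x H x W. Each sample gets its own freshly drawn
    Xavier-normal 3x3 kernel (3 -> 3 channels), shared by all frames of that
    sample. Returns (output with the same shape as imgs, kernels of shape B x 3 x 3 x 3 x 3).
    """
    b, c, h, w = imgs.shape
    stack = c // 3
    # one (out=3, in=3, 3, 3) kernel per sample; init each 4-D slice so the
    # fans match nn.Conv2d(3, 3, kernel_size=3): fan_in = fan_out = 27
    kernels = torch.empty(b, 3, 3, 3, 3, device=imgs.device, dtype=imgs.dtype)
    for i in range(b):
        nn.init.xavier_normal_(kernels[i])
    # (B, stack*3, H, W) -> (stack, B*3, H, W): frames become the conv batch,
    # samples become channel groups
    x = imgs.reshape(b, stack, 3, h, w).permute(1, 0, 2, 3, 4).reshape(stack, b * 3, h, w)
    # groups=b: group i sees only sample i's 3 channels and uses only kernels[i]
    out = F.conv2d(x, kernels.reshape(b * 3, 3, 3, 3), padding=1, groups=b)
    # undo the layout change: (stack, B*3, H, W) -> (B, stack*3, H, W)
    out = out.reshape(stack, b, 3, h, w).permute(1, 0, 2, 3, 4).reshape(b, c, h, w)
    return out, kernels


torch.manual_seed(0)
frame = torch.randn(1, 3, 84, 84)
# two identical samples, each a stack of 3 identical frames: shape (2, 9, 84, 84)
imgs = frame.repeat(2, 3, 1, 1)
out, kernels = random_convolution_grouped(imgs)

# per-sample reference in the style of the post's loop, reusing the same kernels
ref = torch.cat([
    F.conv2d(imgs[i:i + 1].reshape(-1, 3, 84, 84), kernels[i], padding=1).reshape(1, 9, 84, 84)
    for i in range(2)
])
```

Two identical input samples come out different (each has its own kernel), while identical frames inside one sample stay identical (they share a kernel), which matches the per-sample behavior described above.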
posted @ 2022-06-24 16:42  呦呦南山