[Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint

Cyr E. C., Gulian M., Patel R. G., Perego M., Trask N. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint. arXiv preprint, 2019.

@article{cyr2019robust,
  title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint},
  author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},
  journal={arXiv: Learning},
  year={2019}
}

This paper presents an improvement to gradient descent (LSGD), together with the Box parameter initialization method.

Main content

Viewing the network as an adaptive basis expansion $u(x) \approx \sum_i \xi_i^L \Phi_i(x, \xi_H)$, where the $\Phi_i$ are the features produced by the hidden layers with parameters $\xi_H$ and $\xi_L$ holds the coefficients of the final linear layer, training with operators $\mathcal{L}_k$ (e.g. a PDE and its boundary conditions), penalties $\epsilon_k$, and collocation sets $X_k$ amounts to

$$\mathop{\mathrm{argmin}}_{\xi_L,\, \xi_H} \sum_{k=1}^K \epsilon_k \Big\| \mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k[\Phi_i(x, \xi_H)] \Big\|_{\ell_2(X_k)}^2. \tag{6}$$

LSGD

Fix $\xi_H$ and $X_k$, and set $\epsilon_k = 1$; then problem (6) reduces to the least-squares problem

$$\mathop{\mathrm{argmin}}_{\xi_L} \|A \xi_L - b\|_{\ell_2(X)}^2,$$

where $b_i = \mathcal{L}[u](x_i)$, $A_{ij} = \mathcal{L}[\Phi_j(x_i, \xi_H)]$, $x_i \in X$, $i = 1, \ldots, N$, $j = 1, \ldots, w$.

This yields the LSGD algorithm, which alternates the least-squares solve for $\xi_L$ with gradient-descent updates of the hidden parameters $\xi_H$:

[Figure: the LSGD algorithm as given in the paper]
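A minimal PyTorch sketch of this alternation (my own reconstruction, not the authors' code; `hidden`, `last`, and `lsgd_step` are hypothetical names, and the last layer is assumed to be a bias-free nn.Linear acting on the hidden features):

import torch

def lsgd_step(hidden, last, x, target, optimizer):
    # Step 1: with the hidden parameters xi_H fixed, the features form a fixed
    # N x w matrix A, and the optimal last-layer coefficients solve
    # argmin ||A xi_L - b||^2 exactly.
    with torch.no_grad():
        A = hidden(x)                                   # N x w feature matrix
        xi_L = torch.linalg.lstsq(A, target).solution   # w x out coefficients
        last.weight.copy_(xi_L.t())
    # Step 2: one gradient step on the hidden parameters only.
    optimizer.zero_grad()
    loss = ((last(hidden(x)) - target) ** 2).mean()
    loss.backward()
    optimizer.step()
    return loss.item()

Here `optimizer` is built over hidden.parameters() only, so the last layer stays at its least-squares optimum between solves.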

Box initialization

The algorithm is meant to make the layers feature-rich, though I do not know where exactly this richness comes from.

Suppose layer $l$ has input $x \in \mathbb{R}^{d_1}$ and output $y \in \mathbb{R}^{d_2}$, so its weight matrix is $W \in \mathbb{R}^{d_2 \times d_1}$. We define $W$ row by row:

  1. Sample $p \sim U[0,1]^{d_1}$;
  2. Sample $n \sim N(0, I_{d_1})$ and set $n \leftarrow n / \|n\|$;
  3. Find the parameter $k$ such that

$$\max_{x \in [0,1]^{d_1}} \sigma(k (x - p) \cdot n) = 1;$$

  4. Set $W_i = w_i = k n^T$ and $b_i = -k\, p \cdot n$.

Here $\sigma$ denotes the activation function, which in the paper is ReLU.
To solve for the parameter $k$:

  1. $p_{\max} = \max(0, \operatorname{sign}(n))$ (elementwise);
  2. $k = \dfrac{1}{(p_{\max} - p) \cdot n}$.

This $k$ is the desired one; it suffices to show that $p_{\max}$ maximizes

$$(x - p) \cdot n, \quad x \in [0,1]^{d_1}.$$

Maximizing this expression decomposes coordinatewise into

$$\max_{x_i \in [0,1]} x_i n_i,$$

whose solution is $x_i = \max(0, \operatorname{sign}(n_i))$.
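A quick numerical check of this (my own sketch, not from the post): since $(x - p) \cdot n$ is linear in $x$, its maximum over the box is attained at a corner, so it is enough to evaluate all corners.

import torch

d1 = 3
p = torch.rand(d1)                           # p ~ U[0,1]^d1
n = torch.randn(d1); n = n / n.norm()        # unit direction
pmax = torch.clamp(torch.sign(n), min=0.0)   # coordinatewise maximizer
k = 1.0 / torch.dot(pmax - p, n)
# evaluate sigma(k (x - p) . n) at all 2^3 corners of [0,1]^3; the max should be 1
corners = torch.cartesian_prod(*[torch.tensor([0.0, 1.0])] * d1)
vals = torch.relu(k * (corners - p) @ n)
print(vals.max())                            # tensor(1.)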

What does this initialization buy us? Notice that if the input satisfies $x \in [0,1]^{d_1}$, then the output satisfies $y \in [0,1]^{d_2}$: with ReLU each output coordinate is nonnegative, and by the choice of $k$ it attains at most 1 over the input box. Input and output thus share the same "range", and by induction the values at every node of the network stay in roughly the same range.
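To see this numerically, here is a sketch that builds one full Box layer row by row (it mirrors the generate function in the code section below; dimensions are my own choice) and checks that random inputs from $[0,1]^{d_1}$ stay in $[0,1]^{d_2}$:

import torch

d1, d2 = 4, 5
P = torch.rand(d2, d1)                        # one p per row
N = torch.randn(d2, d1)
N = N / N.norm(dim=1, keepdim=True)           # one unit direction per row
K = 1.0 / ((torch.clamp(torch.sign(N), min=0.0) - P) * N).sum(dim=1, keepdim=True)
W = K * N
b = -(W * P).sum(dim=1)
x = torch.rand(10000, d1)                     # samples from [0,1]^d1
y = torch.relu(x @ W.t() + b)
print(y.min().item(), y.max().item())         # both inside [0, 1]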

[Figure: layer-by-layer outputs of a 2-2-2-2-2-2-2-2 network under Box, Xavier, and Kaiming initialization]

For example, the authors construct a 2-2-2-2-2-2-2-2 network. One can see that with Xavier and Kaiming initialization the outputs collapse onto a single point after a few layers, while Box initialization alleviates this.

Below is the algorithm as stated in the paper (the notation differs slightly from the one used here; also, the authors appear to have dropped the minus sign in $b$).

[Figure: the Box initialization algorithm as given in the paper]

Box for ResNet

Because of ResNet's special structure, a layer acts as

$$y = (W + I)x + b.$$

Suppose $x \in [0, m]^{d_1}$; then:

  1. Sample $p \sim U[0, m]^{d_1}$;
  2. Sample $n \sim N(0, I_{d_1})$ and set $n \leftarrow n / \|n\|$;
  3. Find the parameter $k$ such that

$$\max_{x \in [0,m]^{d_1}} \sigma(k (x - p) \cdot n) = \delta m;$$

  4. Set $W_i = w_i = k n^T$ and $b_i = -k\, p \cdot n$.

$$k = \frac{\delta m}{(m\, p_{\max} - p) \cdot n}.$$

If the inputs to the first layer satisfy $x_i \in [0,1]$, take $\delta = 1/L$, where $L$ is the total number of layers. The ranges then grow as

$$[0,1] \rightarrow \Big[0,\, 1 + \tfrac{1}{L}\Big] \rightarrow \Big[0,\, \big(1 + \tfrac{1}{L}\big)^2\Big] \rightarrow \cdots$$

and since $(1 + 1/L)^L \le e$, the range stays bounded regardless of depth. The ResNet test code below visualizes this.
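A numeric sketch of this growth (my own check with hypothetical dimensions, reusing the same row construction with the $(m, \delta)$ scaling):

import torch

L = 6
delta = 1.0 / L
m = 1.0
x = torch.rand(10000, 2)                      # first-layer inputs in [0,1]^2
for l in range(L):
    P = torch.rand(2, 2) * m                  # p ~ U[0, m]^2 per row
    N = torch.randn(2, 2); N = N / N.norm(dim=1, keepdim=True)
    pmax = torch.clamp(torch.sign(N), min=0.0) * m
    K = (m * delta) / ((pmax - P) * N).sum(dim=1, keepdim=True)
    W = K * N
    b = -(W * P).sum(dim=1)
    x = torch.relu(x @ W.t() + b) + x         # residual block
    m = m * (1 + delta)                       # range grows to [0, (1 + 1/L)^(l+1)]
print(x.max().item(), m)                      # max stays below (1 + 1/L)^L <= e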


Code




'''
initialization.py
'''
import torch
import torch.nn as nn
import warnings


def generate(size, m, delta):
    """Box rows for `size` = (d2, d1): weights w and biases b for inputs in [0, m]^d1."""
    p = torch.rand(size) * m                           # p ~ U[0, m]^d1, one per row
    n = torch.randn(size)
    n = n / torch.norm(n, p=2, dim=1, keepdim=True)    # unit direction per row
    pmax = nn.functional.relu(torch.sign(n)) * m       # corner of [0, m]^d1 maximizing (x - p) . n
    k = (m * delta) / ((pmax - p) * n).sum(dim=1, keepdim=True)
    w = k * n                                          # w_i = k_i n_i^T
    b = -(w * p).sum(dim=1)                            # b_i = -k_i p_i . n_i
    return w, b

def box_init(module, m=1, delta=1):
    """Apply Box initialization to a Linear or Conv2d module; leave others untouched."""
    if isinstance(module, nn.Linear):
        w, b = generate(module.weight.shape, m, delta)
        try:
            module.weight.data = w
            module.bias.data = b
        except AttributeError as e:
            s = "Error: \n" + str(e) + "\n stops the initialization" \
                                       " for this module: {}".format(module)
            warnings.warn(s)

    elif isinstance(module, nn.Conv2d):
        # flatten each (in_channels x kh x kw) kernel into one row of the Box scheme
        outc, inc, kh, kw = module.weight.size()
        w, b = generate((outc, inc * kh * kw), m, delta)
        try:
            module.weight.data = w.reshape(module.weight.size())
            module.bias.data = b
        except AttributeError as e:
            s = "Error: \n" + str(e) + "\n stops the initialization" \
                                       " for this module: {}".format(module)
            warnings.warn(s)
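To apply this to a whole model in one line (a usage sketch; nn.Module.apply passes every submodule to the function, which is what the test networks below do by hand):

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
model.apply(box_init)   # every Linear/Conv2d gets Box weights; other modules are skipped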




"""config.py"""

nums = 10
layers = 6
method = "kaiming"  #box/xavier/kaiming
net = "Net"  #Net/ResNet






"""
测试
"""



import torch
import torch.nn as nn
import config
from initialization import box_init



class Net(nn.Module):

    def __init__(self, l):
        super(Net, self).__init__()

        self.linears = []
        for i in range(l):
            name = "linear" + str(i)
            self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
                                                 nn.ReLU()))
            self.linears.append(self.__getattr__(name))
        if config.method == 'box':
            self.box_init()
        elif config.method == "xavier":
            self.xavier_init()
        else:
            self.kaiming_init()

    def box_init(self):
        for module in self.modules():
            box_init(module)

    def xavier_init(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)

    def kaiming_init(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        out = []
        temp = x
        for linear in self.linears:
            temp = linear(temp)
            out.append(temp)
        return out



class ResNet(nn.Module):

    def __init__(self, l):
        super(ResNet, self).__init__()

        self.linears = []
        for i in range(l):
            name = "linear" + str(i)
            self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
                                                 nn.ReLU()))
            self.linears.append(self.__getattr__(name))
        if config.method == 'box':
            self.box_init(l)
        elif config.method == "xavier":
            self.xavier_init()
        else:
            self.kaiming_init()

    def box_init(self, layers):
        # delta = 1/L, so the ranges grow as [0, (1 + 1/L)^l] across the blocks
        delta = 1 / layers
        m = 1. + delta
        l = 0
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # block l sees inputs in [0, (1 + delta)^l]
                box_init(module, m ** l, delta)
                l += 1

    def xavier_init(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)

    def kaiming_init(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        out = []
        temp = x
        for linear in self.linears:
            temp = linear(temp) + temp
            out.append(temp)
        return out


if config.net == "Net":
    net = Net(config.layers)
else:
    net = ResNet(config.layers)

x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums)

grid_x, grid_y = torch.meshgrid(x, y)

x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data)


import matplotlib.pyplot as plt


def axplot(x, y, ax):
    x = x.detach().numpy()
    y = y.detach().numpy()
    ax.scatter(x, y)

def plot(x, y, outs):
    fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
    axs[0].scatter(x, y)
    axs[0].set(title="layer0")
    for i in range(config.layers):
        ax = axs[i+1]
        out = outs[i]
        x = out[:, 0]
        y = out[:, 1]
        axplot(x, y, ax)
        ax.set(title="layer"+str(i+1))
    plt.tight_layout()
    plt.savefig("fig.png")
    #plt.show()
plot(x, y, outs)






