人工智能实战2019 - 第3次作业 - 王铈弘

项目 内容
课程 人工智能实战2019
作业要求 第3次作业
课程目标 学习人工智能基础知识
本次作业对我的帮助 学习随机梯度下降的三种方法,理解损失函数图像的内涵
理论课程 梯度下降的三种形式

使用Mini-Batch方式进行梯度下降

要求

  • 采用随机选取数据的方式
  • batch size 分别选择5、10、15进行运行

代码实现

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

x_data_name = "TemperatureControlXData.dat"
y_data_name = "TemperatureControlYData.dat"


def ReadData():
    Xfile = Path(x_data_name)
    Yfile = Path(y_data_name)
    if Xfile.exists() & Yfile.exists():
        X = np.load(Xfile)
        Y = np.load(Yfile)
        return X.reshape(1,-1),Y.reshape(1,-1)
    else:
        return None,None

def shuffle_batch(X, Y, batch_size):    
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X)
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, Y_batch = X[batch_idx], Y[batch_idx]
        yield X_batch, Y_batch

def ForwardCalculationBatch(W,B,batch_x):
    Z = np.dot(W, batch_x) + B
    return Z

def BackPropagationBatch(batch_x, batch_y, batch_z):
    m = batch_x.shape[1]
    dZ = batch_z - batch_y
    dB = dZ.sum(axis=1, keepdims=True)/m
    dW = np.dot(dZ, batch_x.T)/m
    return dW, dB

def UpdateWeights(w, b, dW, dB, eta):
    w = w - eta*dW
    b = b - eta*dB
    return w,b

def InitialWeights(num_input, num_output, flag):
    if flag == 0:
        # zero,全零初始化
        W = np.zeros((num_output, num_input))
    elif flag == 1:
        # normalize,高斯分布初始化
        W = np.random.normal(size=(num_output, num_input))
    elif flag == 2:
        # xavier,均匀分布初始化
        W=np.random.uniform(
            -np.sqrt(6/(num_input+num_output)),
            np.sqrt(6/(num_input+num_output)),
            size=(num_output,num_input))

    B = np.zeros((num_output, 1))
    return W,B

def CheckLoss(W, B, X, Y):
    m = X.shape[1]
    Z = np.dot(W, X) + B
    LOSS = (Z - Y)**2
    loss = LOSS.sum()/m/2
    return loss


if __name__ == '__main__':
    
    eta = 0.1
    size = {5,15,20}
    max_epoch = 50

    X, Y = ReadData()
   
    num_example = X.shape[1]
    num_feature = X.shape[0]

    for batch_size in size:
        W, B = InitialWeights(1,1,2)
        loss = []
        for epoch in range(max_epoch):
            for batch_x,batch_y in shuffle_batch(X,Y,batch_size):
                batch_z = ForwardCalculationBatch(W, B, batch_x)
                dW, dB = BackPropagationBatch(batch_x, batch_y, batch_z)
                W, B = UpdateWeights(W, B, dW, dB, eta)
     
            temp = CheckLoss(W,B,X,Y)
            loss.append(temp)
        plt.plot(loss)
        plt.legend(['batch_size : 5','batch_size : 10','batch_size : 15'],loc ='upper right')  
    plt.show()

运行结果

  • 学习率eta = 0.01
    At2Dkd.png
  • 学习率eta = 0.1
    At2sfI.png
  • 学习率eta = 0.2
    At26pt.png
  • 学习率eta = 0.5
    At2c1P.png

总结

  • 随着学习率eta的增大,会在极值点附近产生跳跃,曲线波动增大
  • batch size的的改变,会改变学习速度;batch size过大或过小都不利于提高学习速度
  • Mini-Batch综合了全批量梯度下降BGD和随机梯度下降SGD的优点,在更新速度与更新次数中取一个平衡。相较于BGD,提高了每次的学习速度,不必担心内存瓶颈;相较于SGD,降低了收敛波动性,使更新更加稳定。

梯度下降算法的问题与挑战

  • 很难选择一个合理的学习速率
  • 学习速率调整都需要事先进行固定设置,无法自适应每次学习的数据集特点(比如对于很少出现的特征,应用较大的学习速率)
  • 对于非凸目标函数,容易陷入局部最优(虽然使用SGD,会增大得到全局最优的可能性)

关于损失函数的2D示意图的思考题

1. 为什么是椭圆而不是圆?

由题设知:

\[J(w, b)=\frac{1}{2 m} \sum_{i}^{m}\left(a_{i}-y_{i}\right)^{2}=\frac{1}{2 m} \sum_{i}^{m}\left(\omega x+b-y_{i}\right)^{2} \]

令J=z,整理损失函数表达式可得:

\[\left(\sum_{i}^{m} x_{i}^{2}\right) w^{2}+m b^{2}+\left(2 \sum_{i}^{m} x_{i}\right) w b-\left(2 \sum_{i}^{m} x_{i} y_{i}\right) w-\left(2 \sum_{i}^{m} y_{i}\right) b+\left(\sum_{i}^{m} y_{i}^{2}\right)=2 m z \]

椭圆抛物面的标准方程可推知,变换后的损失函数表达式为椭圆抛物面的一般方程。故投影到2维平面是一般椭圆。

2. 如何把这个图变成一个圆?

只需使变换后的损失函数表达式与回转抛物面的一般方程一致即可。
交叉项前的系数为0,二次项前的系数相等:

\[\left(2 \sum_{i}^{m} x_{i}\right)=0 \quad \left(\sum_{i}^{m} x_{i}^{2}\right)=m \]

代码实现

    x = x - x.mean(axis=0)  #使交叉项前的系数为0
    x = x*np.sqrt(len(x)/np.sum(np.square(x)))    #使二次方前的系数相等

运行结果(主程序引用自:Microsoft/ai-edu

AYxZwj.png

3. 为什么中心是个椭圆区域而不是一个点?

  • 从数学角度分析,利用偏导数的知识可以求出唯一最优解,故椭圆抛物面中心是一个点。
  • 我们绘制的图像其实是由很多离散的点组成,并非连续曲面。由于我们使用的梯度下降算法属于数值解法,会无限逼近数学解析最优解,而满足给定误差限的点有无穷多个,故中心形成区域。
posted @ 2019-03-25 18:02  WangShihong  阅读(327)  评论(0编辑  收藏  举报