CS231n Batch Normalization: gradient derivation and code implementation (BN, LN)


Copyright notice: this is an original post by the author, released under the CC 4.0 BY-SA license. Please include the original link and this notice when reposting.
Original link: https://blog.csdn.net/duan_zhihua/article/details/83107615

This post derives the gradients of Batch Normalization and reproduces them in code.

The paper with the original formulas: https://arxiv.org/abs/1502.03167

Batch Normalization computation graph:
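The computation-graph figure from the original post is not reproduced here; the forward pass it depicts is Algorithm 1 of the paper, computed over a mini-batch of m samples:

\[
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i,\qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i-\mu_B\right)^2,\qquad
\hat{x}_i = \frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}},\qquad
y_i = \gamma\,\hat{x}_i + \beta
\]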

Mathematical derivation of the Batch Normalization gradients:
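The derivation figures are likewise not reproduced; the chain-rule expressions from the paper that they work through are the gradients with respect to the normalized values, the mini-batch variance, and the mini-batch mean:

\[
\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma,\qquad
\frac{\partial L}{\partial \sigma_B^2} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{x}_i}\,(x_i-\mu_B)\cdot\left(-\tfrac{1}{2}\right)\left(\sigma_B^2+\epsilon\right)^{-3/2},\qquad
\frac{\partial L}{\partial \mu_B} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_B^2+\epsilon}}
\]

(The paper's gradient with respect to the mean also contains a term flowing through the variance, but it vanishes because the centered values sum to zero, so the code below omits it.)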

The final result for the gradient with respect to x_i, obtained by combining its three backward paths:
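Combining the three paths (directly through the normalized value, through the mini-batch variance, and through the mini-batch mean) gives the gradients the code computes:

\[
\frac{\partial L}{\partial x_i}
= \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_B^2+\epsilon}}
+ \frac{\partial L}{\partial \sigma_B^2}\cdot\frac{2\left(x_i-\mu_B\right)}{m}
+ \frac{\partial L}{\partial \mu_B}\cdot\frac{1}{m},\qquad
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}\,\hat{x}_i,\qquad
\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}
\]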

The paper's formulas are implemented in code as follows:

import numpy as np  # the assignment file assumes NumPy is imported as np


def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7
    implementation of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        # Training-time forward pass: normalize with minibatch statistics, then
        # scale and shift with gamma and beta, and update the running averages.
        # Note that the data is divided by the standard deviation sqrt(var + eps),
        # not by the variance. Formulas: https://arxiv.org/abs/1502.03167
        mean_x = np.mean(x, axis=0)                  # minibatch mean, shape (D,)
        var_x = np.var(x, axis=0)                    # minibatch variance, shape (D,)
        x_hat = (x - mean_x) / np.sqrt(var_x + eps)  # normalized data
        out = gamma * x_hat + beta                   # scale and shift
        running_mean = momentum * running_mean + (1 - momentum) * mean_x
        running_var = momentum * running_var + (1 - momentum) * var_x
        inv_var_x = 1 / np.sqrt(var_x + eps)         # cached for the backward pass
        cache = (x, x_hat, gamma, mean_x, inv_var_x)
    elif mode == 'test':
        # Test-time forward pass: normalize with the running mean and variance,
        # then scale and shift with gamma and beta.
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_hat + beta
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
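As a quick sanity check of the forward pass, it can be called like this (a minimal sketch; the shapes, seed and values below are arbitrary and only for illustration):

np.random.seed(0)
N, D = 4, 3
x = np.random.randn(N, D) * 5 + 10     # data with non-zero mean and non-unit variance
gamma, beta = np.ones(D), np.zeros(D)
bn_param = {'mode': 'train'}

out, cache = batchnorm_forward(x, gamma, beta, bn_param)
print(out.mean(axis=0))                # each feature should be close to 0
print(out.std(axis=0))                 # each feature should be close to 1

# Switching to test mode uses the running statistics stored in bn_param.
bn_param['mode'] = 'test'
out_test, _ = batchnorm_forward(x, gamma, beta, bn_param)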
def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.

    For this implementation, you should write out a computation graph for
    batch normalization on paper and propagate gradients backward through
    intermediate nodes.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = None, None, None
    # Computation graph (forward direction):
    #   x_i --> mu_B --> sigma_B^2 --> x_hat_i --> y_i --> L
    # x_i reaches the loss through three paths (directly into x_hat_i, through
    # sigma_B^2, and through mu_B); gamma and beta enter only through y_i.
    x, x_hat, gamma, mean_x, inv_var_x = cache
    N = x.shape[0]

    # dx is the sum of the three paths:
    # path 1: L --> x_hat_i --> x_i
    dx = gamma * dout * inv_var_x
    # path 2: L --> sigma_B^2 --> x_i
    dx += (2 / N) * (x - mean_x) * np.sum(-(1 / 2) * inv_var_x ** 3 * (x - mean_x) * gamma * dout, axis=0)
    # path 3: L --> mu_B --> x_i
    dx += (1 / N) * np.sum(-1 * inv_var_x * gamma * dout, axis=0)

    # dgamma: L --> y_i --> gamma
    dgamma = np.sum(x_hat * dout, axis=0)
    # dbeta: L --> y_i --> beta
    dbeta = np.sum(dout, axis=0)

    return dx, dgamma, dbeta
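The derivation can be verified numerically by comparing the analytic gradient with a centered-difference estimate. This is a minimal sketch: the helper num_grad below is written here only for illustration (it is not part of the assignment code), and the scalar loss np.sum(out * dout) is chosen so that its gradient with respect to out is exactly dout.

def num_grad(f, x, h=1e-5):
    # Centered-difference numerical gradient of the scalar function f at x.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + h
        fp = f(x)
        x[idx] = old - h
        fm = f(x)
        x[idx] = old
        grad[idx] = (fp - fm) / (2 * h)
        it.iternext()
    return grad

np.random.seed(1)
N, D = 5, 4
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)
dout = np.random.randn(N, D)

_, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

f = lambda x: np.sum(batchnorm_forward(x, gamma, beta, {'mode': 'train'})[0] * dout)
dx_num = num_grad(f, x)
print(np.max(np.abs(dx - dx_num)))     # should be tiny, roughly 1e-8 or smaller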

batchnorm_backward_alt
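Substituting the variance and mean gradients back into the expression for the input gradient and collecting terms yields the single-line formula that batchnorm_backward_alt implements:

\[
\frac{\partial L}{\partial x_i}
= \frac{1}{m\sqrt{\sigma_B^2+\epsilon}}
\left( m\,\frac{\partial L}{\partial \hat{x}_i}
- \sum_{j=1}^{m}\frac{\partial L}{\partial \hat{x}_j}
- \hat{x}_i\sum_{j=1}^{m}\frac{\partial L}{\partial \hat{x}_j}\,\hat{x}_j \right),
\qquad
\frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i}
\]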

The code implementation is as follows:

def batchnorm_backward_alt(dout, cache):
    """
    Alternative backward pass for batch normalization.

    For this implementation you should work out the derivatives for the batch
    normalization backward pass on paper and simplify as much as possible. You
    should be able to derive a simple expression for the backward pass.
    See the jupyter notebook for more hints.

    Note: This implementation should expect to receive the same cache variable
    as batchnorm_backward, but might not use all of the values in the cache.

    Inputs / outputs: Same as batchnorm_backward
    """
    dx, dgamma, dbeta = None, None, None
    x, x_hat, gamma, mean_x, inv_var_x = cache
    N = x.shape[0]
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(x_hat * dout, axis=0)
    dxhat = dout * gamma
    # The three paths of batchnorm_backward collapsed into a single expression.
    dx = (1. / N) * inv_var_x * (N * dxhat - np.sum(dxhat, axis=0) -
                                 x_hat * np.sum(dxhat * x_hat, axis=0))

    return dx, dgamma, dbeta
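The two backward implementations should agree to floating-point precision. A quick comparison (again a minimal sketch with arbitrary shapes and seed):

np.random.seed(2)
N, D = 100, 50
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)
dout = np.random.randn(N, D)

_, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
dx1, dgamma1, dbeta1 = batchnorm_backward(dout, cache)
dx2, dgamma2, dbeta2 = batchnorm_backward_alt(dout, cache)

print(np.max(np.abs(dx1 - dx2)))       # round-off level (~1e-15): the two derivations agree
print(np.max(np.abs(dgamma1 - dgamma2)))
print(np.max(np.abs(dbeta1 - dbeta2)))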

Layer normalization:
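Layer normalization applies the same normalize-scale-shift transform, but the statistics are computed over the D features of each individual sample rather than over the N samples in the batch:

\[
\mu_i = \frac{1}{D}\sum_{d=1}^{D} x_{i,d},\qquad
\sigma_i^2 = \frac{1}{D}\sum_{d=1}^{D}\left(x_{i,d}-\mu_i\right)^2,\qquad
\hat{x}_{i,d} = \frac{x_{i,d}-\mu_i}{\sqrt{\sigma_i^2+\epsilon}},\qquad
y_{i,d} = \gamma_d\,\hat{x}_{i,d} + \beta_d
\]

The implementation below reuses the batch-norm code simply by transposing x, so that normalizing over axis 0 becomes normalizing over each sample's features.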

def layernorm_forward(x, gamma, beta, ln_param):
    """
    Forward pass for layer normalization.

    During both training and test-time, the incoming data is normalized per data-point,
    before being scaled by gamma and beta parameters identical to that of batch normalization.

    Note that in contrast to batch normalization, the behavior during train and test-time for
    layer normalization are identical, and we do not need to keep track of running averages
    of any sort.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - ln_param: Dictionary with the following keys:
      - eps: Constant for numeric stability

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    out, cache = None, None
    eps = ln_param.get('eps', 1e-5)
    # Transpose x so that the batch-norm code, which normalizes over axis 0,
    # now normalizes each data point over its features.
    # x: (N, D) ----> (D, N)
    x = x.T
    mean_x = np.mean(x, axis=0)                  # per-sample mean, shape (N,)
    var_x = np.var(x, axis=0)                    # per-sample variance, shape (N,)
    inv_var_x = 1 / np.sqrt(var_x + eps)

    x_hat = (x - mean_x) / np.sqrt(var_x + eps)  # (D, N)
    x_hat = x_hat.T                              # (D, N) ----> (N, D)
    # gamma: (D,)  beta: (D,)
    out = gamma * x_hat + beta
    # Cache the transposed x as well, so that layernorm_backward can unpack the
    # same five values as the batch-norm cache and read D from x.shape[0].
    cache = (x, x_hat, gamma, mean_x, inv_var_x)

    return out, cache
def layernorm_backward(dout, cache):
    """
    Backward pass for layer normalization.

    For this implementation, you can heavily rely on the work you've done already
    for batch normalization.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from layernorm_forward.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = None, None, None
    # x in the cache is the transposed input of shape (D, N), so x.shape[0] is the
    # number of features D -- the axis layer norm normalizes over.
    x, x_hat, gamma, mean_x, inv_var_x = cache
    d = x.shape[0]
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(x_hat * dout, axis=0)
    dxhat = dout * gamma
    # Transpose back to (D, N) so the simplified batch-norm backward formula can be
    # reused with the roles of N and D swapped, then transpose the result back.
    dxhat = dxhat.T
    x_hat = x_hat.T
    dx = (1. / d) * inv_var_x * (d * dxhat - np.sum(dxhat, axis=0) -
                                 x_hat * np.sum(dxhat * x_hat, axis=0))
    dx = dx.T

    return dx, dgamma, dbeta
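A quick check of the layer-norm pair (a minimal sketch with arbitrary shapes): each row of the output should have roughly zero mean and unit standard deviation, and the gradient shapes should match the inputs:

np.random.seed(3)
N, D = 4, 6
x = np.random.randn(N, D) * 3 + 7
gamma, beta = np.ones(D), np.zeros(D)

out, cache = layernorm_forward(x, gamma, beta, {'eps': 1e-5})
print(out.mean(axis=1))                # per-sample means, approximately 0
print(out.std(axis=1))                 # per-sample stds, approximately 1

dout = np.random.randn(N, D)
dx, dgamma, dbeta = layernorm_backward(dout, cache)
print(dx.shape, dgamma.shape, dbeta.shape)   # (4, 6) (6,) (6,)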

 

 

https://github.com/duanzhihua

Posted on 2019-09-26 15:17 by 曹明