ruijiege

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::
class BatchNormalization(Module):
    """Batch normalization over the channel axis of a 4-D input.

    Statistics are reduced over axes (0, 2, 3) and the per-channel gamma/beta
    are broadcast with ``reshape(1, -1, 1, 1)``, so inputs are expected to be
    shaped (N, C, H, W) with C == ``in_feature``.

    Args:
        in_feature: number of channels C (length of gamma and beta).
        momentum: weight given to the *current batch* statistics when updating
            the running mean/variance:
            ``running = momentum * batch + (1 - momentum) * running``.
            NOTE(review): this is the reverse of the common
            ``running = m * running + (1 - m) * batch`` form; with the default
            0.9 the running statistics follow the latest batches closely --
            confirm this is intended.
        eps: added to the variance before the square root for numerical stability.
        name: display name used by ``Module.info`` / ``repr``.
    """

    def __init__(self, in_feature, momentum=0.9, eps=1e-8, name="BatchNormalization"):
        # Module.__init__ sets ``name`` and ``train_mode``; without it forward()
        # raises AttributeError until train()/eval() has been called.
        super().__init__(name)
        self.mu = 0
        self.var = 1
        self.momentum = momentum
        self.eps = eps
        self.in_feature = in_feature
        self.gamma = Parameter(np.ones(in_feature))
        self.beta = Parameter(np.zeros(in_feature))

    def _affine(self, y):
        """Apply the per-channel scale (gamma) and shift (beta) to normalized ``y``."""
        return y * self.gamma.value.reshape(1, -1, 1, 1) + self.beta.value.reshape(1, -1, 1, 1)

    def forward(self, x):
        """Normalize ``x`` and apply the learned affine transform.

        In train mode the current batch statistics are used and the running
        mean/variance are updated. Otherwise the running statistics are used
        unchanged. Caches what ``backward`` needs (``self.y``, ``self.std``, and
        whether batch statistics were used).
        """
        if not self.train_mode:
            self.std = np.sqrt(self.var + self.eps)
            self.y = (x - self.mu) / self.std
            # Running statistics are constants w.r.t. x, so backward must not
            # propagate gradients through them.
            self.used_batch_stats = False
            return self._affine(self.y)

        self.b_mu = np.mean(x, axis=(0, 2, 3), keepdims=True)
        self.b_var = np.var(x, axis=(0, 2, 3), keepdims=True)
        self.std = np.sqrt(self.b_var + self.eps)
        self.y = (x - self.b_mu) / self.std
        self.used_batch_stats = True

        self.mu = self.b_mu * self.momentum + self.mu * (1 - self.momentum)

        # Elements per channel (N * H * W); used for Bessel's correction of the
        # running variance. np.var above is the biased estimate.
        n = x.size / x.shape[1]
        # With a single element per channel the unbiased estimate is 0/0;
        # fall back to the biased one instead of writing NaN into self.var.
        unbiased_var = self.b_var * n / (n - 1) if n > 1 else self.b_var
        self.var = unbiased_var * self.momentum + self.var * (1 - self.momentum)
        return self._affine(self.y)

    def backward(self, G):
        """Back-propagate the upstream gradient ``G`` (same shape as the output).

        Sets ``gamma.delta`` and ``beta.delta`` (summed over axes 0, 2, 3) and
        returns the gradient w.r.t. the input of the most recent ``forward``.
        """
        self.gamma.delta = np.sum(G * self.y, axis=(0, 2, 3))
        self.beta.delta = np.sum(G, axis=(0, 2, 3))
        dy = G * self.gamma.value.reshape(1, -1, 1, 1)
        if not self.used_batch_stats:
            # Eval mode: mean/std were constants, so the map is a per-channel scale.
            return dy / self.std
        # Train mode: the batch mean and variance both depend on x, which adds
        # the two correction terms below (standard batch-norm input gradient).
        mean_dy = np.mean(dy, axis=(0, 2, 3), keepdims=True)
        mean_dy_y = np.mean(dy * self.y, axis=(0, 2, 3), keepdims=True)
        return (dy - mean_dy - self.y * mean_dy_y) / self.std
View Code
class Module:
    """Base class for layers: holds a display name and a train/eval flag, and
    discovers child modules and parameters by scanning instance attributes.

    Subclasses implement ``forward``; calling an instance dispatches to it.
    Only *direct* attributes that are ``Module`` / ``Parameter`` instances are
    discovered -- objects held inside lists or dicts are not.
    """

    def __init__(self, name):
        self.name = name
        self.train_mode = False

    def __call__(self, *args, **kwargs):
        """Run ``forward`` with the given positional and keyword arguments."""
        return self.forward(*args, **kwargs)

    def train(self):
        """Put this module and all descendants in training mode; returns ``self``."""
        self.train_mode = True
        for child in self.modules():
            child.train()
        return self

    def eval(self):
        """Put this module and all descendants in evaluation mode; returns ``self``."""
        self.train_mode = False
        for child in self.modules():
            child.eval()
        return self

    def modules(self):
        """Return the direct child modules, in attribute-assignment order."""
        return [value for value in vars(self).values() if isinstance(value, Module)]

    def params(self):
        """Return this module's own Parameters followed by those of every descendant."""
        ps = [value for value in vars(self).values() if isinstance(value, Parameter)]
        for child in self.modules():
            ps.extend(child.params())
        return ps

    def info(self, n=0):
        """Return a tree description: this module's name, then each child's
        description on its own line, indented two spaces per nesting level.

        Args:
            n: nesting depth of this module (0 for the root).
        """
        lines = [f"{self.name}"]
        for child in self.modules():
            lines.append(('  ' * (n + 1)) + f"{child.info(n + 1)}")
        return "\n".join(lines)

    def __repr__(self):
        return self.info(0)
View Code

 

posted on 2022-10-28 09:52  哦哟这个怎么搞  阅读(12)  评论(0编辑  收藏  举报