import numpy as np


class BatchNormalization(Module):
    def __init__(self, in_feature, momentum=0.9, eps=1e-8):
        super().__init__("BatchNormalization")
        # Running statistics used at inference time.
        self.mu = 0
        self.var = 1
        self.momentum = momentum
        self.eps = eps
        self.in_feature = in_feature
        # Learnable per-channel scale and shift.
        self.gamma = Parameter(np.ones(in_feature))
        self.beta = Parameter(np.zeros(in_feature))

    def forward(self, x):
        # x has shape (N, C, H, W); statistics are computed per channel.
        if not self.train_mode:
            # Inference: normalize with the running statistics.
            y = (x - self.mu) / np.sqrt(self.var + self.eps)
            return y * self.gamma.value.reshape(1, -1, 1, 1) + self.beta.value.reshape(1, -1, 1, 1)
        # Training: normalize with the statistics of the current batch.
        self.b_mu = np.mean(x, axis=(0, 2, 3), keepdims=True)
        self.b_var = np.var(x, axis=(0, 2, 3), keepdims=True)
        self.y = (x - self.b_mu) / np.sqrt(self.b_var + self.eps)
        # Update the running statistics as an exponential moving average:
        # with momentum=0.9 the old running value keeps 90% of its weight.
        self.mu = self.mu * self.momentum + self.b_mu * (1 - self.momentum)
        n = x.size / x.shape[1]                      # samples per channel: N * H * W
        unbiased_var = self.b_var * n / (n - 1)      # Bessel-corrected batch variance
        self.var = self.var * self.momentum + unbiased_var * (1 - self.momentum)
        return self.y * self.gamma.value.reshape(1, -1, 1, 1) + self.beta.value.reshape(1, -1, 1, 1)

    def backward(self, G):
        # Gradients of the affine parameters.
        self.gamma.delta = np.sum(G * self.y, axis=(0, 2, 3))
        self.beta.delta = np.sum(G, axis=(0, 2, 3))
        # Gradient w.r.t. the input, treating the batch statistics as constants
        # (a common simplification of the full batch-norm backward pass).
        return G * self.gamma.value.reshape(1, -1, 1, 1) / np.sqrt(self.b_var + self.eps)
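# The Parameter container referenced above (and by Module.params below) is not
# defined in this snippet; the code only assumes it exposes a `value` array
# holding the weights and a `delta` slot that backward() fills with the
# gradient. A minimal sketch under that assumption (the original library's
# definition may differ):
class Parameter:
    def __init__(self, value):
        self.value = np.asarray(value, dtype=np.float64)  # learnable weights
        self.delta = np.zeros_like(self.value)            # gradient buffer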
class Module:
    """Base class for layers: tracks train/eval mode, collects sub-modules and parameters."""

    def __init__(self, name):
        self.name = name
        self.train_mode = False

    def __call__(self, *args):
        return self.forward(*args)

    def train(self):
        # Switch this module and every sub-module to training mode.
        self.train_mode = True
        for m in self.modules():
            m.train()

    def eval(self):
        # Switch this module and every sub-module to evaluation mode.
        self.train_mode = False
        for m in self.modules():
            m.eval()

    def modules(self):
        # Direct sub-modules found among this module's attributes.
        ms = []
        for attr in self.__dict__:
            m = self.__dict__[attr]
            if isinstance(m, Module):
                ms.append(m)
        return ms

    def params(self):
        # This module's parameters plus those of all sub-modules, recursively.
        ps = []
        for attr in self.__dict__:
            p = self.__dict__[attr]
            if isinstance(p, Parameter):
                ps.append(p)
        for m in self.modules():
            ps.extend(m.params())
        return ps

    def info(self, n):
        # Indented, recursive summary of the module tree.
        output = f"{self.name}\n"
        for m in self.modules():
            output += (' ' * (n + 1)) + f"{m.info(n + 1)}\n"
        return output[:-1]

    def __repr__(self):
        return self.info(0)
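# A small usage sketch showing how the pieces fit together. The input shape
# and the bare SGD step below are illustrative assumptions, not part of the
# library (and in a single script, Module and Parameter must be defined
# before BatchNormalization for the class statement above to run).
if __name__ == "__main__":
    bn = BatchNormalization(in_feature=3)

    # Training mode: batch statistics are used and the running averages updated.
    bn.train()
    x = np.random.randn(8, 3, 16, 16)            # (N, C, H, W)
    out = bn(x)                                  # __call__ dispatches to forward()
    grad_in = bn.backward(np.ones_like(out))     # upstream gradient of ones

    # Bare-bones SGD step on the collected parameters (gamma and beta).
    for p in bn.params():
        p.value -= 0.01 * p.delta

    # Evaluation mode: the running mean/variance are used instead.
    bn.eval()
    out_eval = bn(x)
    print(bn)                                    # module tree via __repr__ / info()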