Handwritten Digit Recognition with a Binary Neural Network
Binary Neural Networks
Concept
A Binary Neural Network (BNN) is a special kind of artificial neural network in which the weights and activations are constrained to binary values (typically +1 or -1) instead of the usual floating-point numbers. Binarization cuts the model's storage requirements (1 bit per weight rather than 32) and its computational complexity, which in some settings translates into faster and more efficient execution.
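As a minimal, self-contained sketch of the idea (the tensor values here are illustrative only), binarization maps each entry to ±1 with torch.sign, and training works around the sign function's zero gradient by passing gradients straight through wherever the input lies in [-1, 1]; this "straight-through estimator" is exactly what the Binarize class below implements:

import torch

x = torch.tensor([0.7, -0.2, 1.5, -3.0])
b = torch.sign(x)                   # tensor([ 1., -1.,  1., -1.])
# sign() has zero gradient almost everywhere, so BNNs train with a
# straight-through estimator: treat d(sign)/dx as 1 inside [-1, 1]
# and 0 outside, i.e. mask the upstream gradient:
ste_mask = (x.abs() <= 1).float()   # tensor([1., 1., 0., 0.])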
Using a BNN Instead of a CNN for Handwritten Digit Recognition
The script below defines a sign-binarization autograd Function with a straight-through estimator, a binarized linear layer built on it, and a fully connected network trained on MNIST.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torchvision import datasets, transforms
from tqdm import tqdm
class Binarize(Function):
    """Sign binarization with a straight-through estimator (STE)."""
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # The small epsilon keeps sign(0) from returning 0.
        return torch.sign(input + 1e-20)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # STE: pass gradients through unchanged inside [-1, 1] and zero
        # them outside; clone instead of mutating grad_output in place.
        grad_input = grad_output.clone()
        grad_input[input > 1] = 0
        grad_input[input < -1] = 0
        return grad_input
class BinarizedLinear(nn.Module):
    """Linear layer whose weights (and optionally inputs) are binarized to ±1."""
    def __init__(self, in_features, out_features, binarize_input=True):
        super(BinarizedLinear, self).__init__()
        self.binarize_input = binarize_input
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        if self.binarize_input:
            x = Binarize.apply(x)
        # Binarize the latent real-valued weights on the fly; no .cuda()
        # needed, the weights already live on the model's device.
        w = Binarize.apply(self.weight)
        return torch.matmul(x, w.t())
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
LR = 0.01
EPOCH = 100
LOG_INTERVAL = 100
model = nn.Sequential(
    # The first layer receives real-valued pixels, so its input is not binarized.
    BinarizedLinear(784, 2048, binarize_input=False),
    nn.BatchNorm1d(2048),
    BinarizedLinear(2048, 2048),
    nn.BatchNorm1d(2048),
    BinarizedLinear(2048, 2048),
    nn.Dropout(0.5),
    nn.BatchNorm1d(2048),
    # A real-valued classifier head maps the 2048 features to the 10 digits.
    nn.Linear(2048, 10)
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Decay the learning rate by 10x every 40 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Standard MNIST mean/std.
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TRAIN_BATCH_SIZE, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TEST_BATCH_SIZE, shuffle=False
)
for epoch in range(EPOCH):
    model.train()
    for idx, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
        data = data.view(-1, 28 * 28).to(DEVICE)
        label = label.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, label)
        loss.backward()
        optimizer.step()
        # Clamp the latent real-valued weights to [-1, 1] so they stay
        # inside the window where the straight-through estimator passes
        # gradients.
        for p in model.parameters():
            p.data.clamp_(-1, 1)
        # Evaluate (and log) only every LOG_INTERVAL batches instead of
        # running the full test set after every single batch.
        if idx % LOG_INTERVAL == 0:
            model.eval()
            correct_num = 0
            total_num = 0
            with torch.no_grad():
                for t_data, t_label in test_loader:
                    t_output = model(t_data.view(-1, 28 * 28).to(DEVICE))
                    pred = t_output.max(1)[1]
                    correct_num += (pred == t_label.to(DEVICE)).sum().item()
                    total_num += len(t_data)
            acc = correct_num / total_num
            print('...Testing @ Epoch %03d\tAcc:%.4f' % (epoch, acc))
            print(loss.item())
            model.train()
    scheduler.step()  # update the learning rate once per epoch
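To make the storage claim from the concept section concrete, here is a small sketch (an add-on, not part of the original script) that packs the sign bits of the first layer's trained weights into bytes with numpy.packbits; the exact byte counts assume the 784x2048 first layer defined above:

import numpy as np

w = model[0].weight.detach().cpu().numpy()   # (2048, 784) float32 weights
bits = (w >= 0).astype(np.uint8)             # 1 for sign +1, 0 for sign -1
packed = np.packbits(bits)                   # 8 binarized weights per byte
print(w.nbytes, '->', packed.nbytes)         # 6422528 -> 200704 bytes (32x smaller)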