Handwritten Digit Recognition with a Binary Neural Network

Binary Neural Networks

Concept

A binary neural network (Binary Neural Network, BNN) is a special kind of artificial neural network in which the weights and activations are constrained to binary values (typically +1 or -1) instead of conventional floating-point numbers. Binarization reduces the model's storage requirements (1 bit per weight instead of 32) and its computational complexity, which can improve running speed and efficiency in some settings.
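
As a minimal illustration of deterministic binarization (a standalone sketch, not part of the model below), the sign function maps every entry of a real-valued tensor to ±1:

import torch

x = torch.tensor([0.7, -1.3, 0.0, 2.5])
# sign() maps positives to +1 and negatives to -1; the tiny epsilon breaks
# the tie at exactly 0 so that no entry is mapped to 0
print(torch.sign(x + 1e-20))  # tensor([ 1., -1.,  1.,  1.])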

Using a BNN in Place of a CNN for Handwritten Digit Recognition

The script below trains a binarized multilayer perceptron on MNIST instead of a convolutional network. The optimizer updates full-precision "latent" weights, which are binarized on every forward pass.

import torch
import torch.nn as nn
import math
from torch.autograd import Function
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm

class Binarize(Function):
    """Sign binarization with a straight-through estimator (STE) gradient."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # The tiny epsilon breaks the tie at 0 so sign() never returns 0
        return torch.sign(input + 1e-20)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged,
        # but zero it where |input| > 1, where the sign function saturates.
        # Clone rather than modify grad_output in place.
        input = ctx.saved_tensors[0]
        grad_input = grad_output.clone()
        grad_input[input.abs() > 1] = 0
        return grad_input
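
A quick sanity check of the custom autograd function (illustrative; the expected values follow directly from the definitions above):

x = torch.tensor([0.3, -2.0, 1.5], requires_grad=True)
y = Binarize.apply(x)   # tensor([ 1., -1.,  1.])
y.sum().backward()
print(x.grad)           # tensor([1., 0., 0.]) -- zeroed where |x| > 1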

class BinarizedLinear(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 binarize_input=True):
        super(BinarizedLinear, self).__init__()
        self.binarize_input = binarize_input
        # Latent full-precision weights; binarized on every forward pass
        self.weight = nn.Parameter(
            torch.Tensor(out_features, in_features)
        )
        nn.init.kaiming_uniform_(
            self.weight, a=math.sqrt(5)
        )

    def forward(self, x):
        # Optionally binarize the activations, then always binarize the weights
        if self.binarize_input:
            x = Binarize.apply(x)
        w = Binarize.apply(self.weight)
        out = torch.matmul(x, w.t())
        return out
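
A quick shape check on a random batch (illustrative only; in the model below every binarized layer is followed by BatchNorm1d):

layer = BinarizedLinear(784, 2048, binarize_input=False)
print(layer(torch.randn(4, 784)).shape)  # torch.Size([4, 2048])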

TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
LR = 0.01
EPOCH = 100
LOG_INTERVAL = 100
# Roughly 10M binarizable weights (784*2048 + 2*2048*2048 = 9,994,240);
# at 1 bit each that is about 1.25 MB, versus about 40 MB in float32
model = nn.Sequential(
    # The first layer sees real-valued pixels, so only its weights are binarized
    BinarizedLinear(784, 2048, binarize_input=False),
    nn.BatchNorm1d(2048),
    BinarizedLinear(2048, 2048),
    nn.BatchNorm1d(2048),
    BinarizedLinear(2048, 2048),
    nn.Dropout(0.5),
    nn.BatchNorm1d(2048),
    # Full-precision output layer over the 10 digit classes
    nn.Linear(2048, 10)
).to("cuda")
optimizer = torch.optim.Adam(
    model.parameters(), lr=LR
)
# Decay the learning rate by 10x every 40 epochs
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=40, gamma=0.1
)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Normalize with the MNIST training set mean and std
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TRAIN_BATCH_SIZE, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TEST_BATCH_SIZE, shuffle=False
)
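
The constants 0.1307 and 0.3081 are the per-pixel mean and standard deviation of the MNIST training set; they can be reproduced with a sketch like this (assumes the dataset has already been downloaded to ./data):

raw = datasets.MNIST('./data', train=True, download=True,
                     transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in raw])       # (60000, 1, 28, 28)
print(pixels.mean().item(), pixels.std().item())    # ~0.1307, ~0.3081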

for epoch in range(EPOCH):
    model.train()
    for idx, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()
        output = model(data.view(-1, 28*28).cuda())
        loss = F.cross_entropy(output, label.cuda())
        loss.backward()
        optimizer.step()
        # Clip the latent weights to [-1, 1]; outside this range the
        # straight-through estimator passes no gradient anyway
        for p in model.parameters():
            p.data.clamp_(-1, 1)

        # Evaluate on the full test set only every LOG_INTERVAL batches
        if idx % LOG_INTERVAL == 0:
            model.eval()
            correct_num = 0
            total_num = 0
            with torch.no_grad():
                for test_data, test_label in test_loader:
                    output = model(test_data.view(-1, 28*28).cuda())
                    pred = output.max(1)[1]
                    correct_num += (pred == test_label.cuda()).sum().item()
                    total_num += len(test_data)
            acc = correct_num / total_num
            print('...Testing @ Epoch %03d\tAcc:%.4f' % (epoch, acc))
            print(loss.item())
            model.train()
    scheduler.step()  # update the learning rate once per epoch
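
After training, one possible way to snapshot the network for deployment is to overwrite each latent weight with its sign, so the saved checkpoint already holds the binary values (a sketch; the filename bnn_mnist.pt is arbitrary, and true 1-bit packing would need extra machinery):

# Replace each latent weight with its sign; the forward pass is unchanged,
# since sign(±1) = ±1
with torch.no_grad():
    for m in model.modules():
        if isinstance(m, BinarizedLinear):
            m.weight.copy_(torch.sign(m.weight + 1e-20))
torch.save(model.state_dict(), 'bnn_mnist.pt')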