Accelerating Deep Learning by Focusing on the Biggest Losers

The idea is simple. During training, every sample incurs a loss \(\mathcal{L}(f(x_i),y_i)\); training usually proceeds in batches, backpropagating the gradients of the whole batch loss \(\sum_i \mathcal{L}(f(x_i),y_i)\) and then updating the parameters. This paper observes that some samples \((x_i,y_i)\) are so redundant that the network learns to recognize them easily, leaving their losses \(\mathcal{L}(f(x_i),y_i)\) comparatively small. The authors therefore design a mechanism under which high-loss samples are selected with high probability while unimportant ones are skipped, cutting computation time. Experiments show that the method reduces training time while keeping accuracy unchanged.

[Figure: the paper's overview diagram comparing training pipelines; see the discussion below]

Related Work

The authors say this algorithm was first proposed in Are Loss Functions All the Same?, but that paper really only discusses the advantages of the hinge loss and analyzes other loss functions.

The authors name Not All Samples Are Created Equal: Deep Learning with Importance Sampling as the most closely related work; that paper starts from a preprocessing perspective (although it, too, has to compute losses) and contains more theory than this one.

Main Content

[Figure: Algorithm 1 (Selective-Backprop) from the paper]

[Figure: Algorithm 2 (the selection-probability calculation) from the paper]
Algorithm 1 is easy to follow; the main subtlety lies in the probability computation of Algorithm 2. Suppose we have already computed and stored the losses of \(n\) samples, and the next sample has loss \(\mathcal{L}_c\). If \(k\) of those \(n\) stored losses are smaller than \(\mathcal{L}_c\), then this sample is selected with probability:

\[\max \{(k/n)^\beta, s\} \]

where \(s\in[0,1]\) is chosen by hand and guarantees that every sample retains some chance of being selected.
We can also impose a maximum length \(r\) and keep the past losses in a double-ended queue (deque): once \(n=r\), storing the next loss evicts the oldest one, which keeps the bookkeeping cost bounded.
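
As a quick sanity check of the formula: with \(n=1000\) stored losses, \(k=900\) of them smaller than \(\mathcal{L}_c\), \(\beta=3\), and \(s=0.1\), the sample is kept with probability \(\max\{0.9^3, 0.1\}=0.729\). Below is a minimal, self-contained sketch of this rule (the deque size and parameter values are illustrative, not the paper's):

```python
import bisect
import collections
import random

history = collections.deque(maxlen=3000)  # bounded store of recent losses
beta, s = 3.0, 0.1

def select_prob(loss_new):
    """max{(k/n)^beta, s}, where k = stored losses smaller than loss_new."""
    if not history:
        return 1.0  # nothing to compare against yet: always keep
    k = bisect.bisect_left(sorted(history), loss_new)  # O(n log n); just a sketch
    return max((k / len(history)) ** beta, s)

for loss in [0.2, 0.5, 1.3, 0.05, 2.0]:  # toy loss stream
    p = select_prob(loss)
    keep = random.random() < p
    print(f"loss={loss:.2f}  prob={p:.3f}  keep={keep}")
    history.append(loss)
```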

```mermaid
graph LR
    A[sample x] --> C(network f)
    C --> D[loss l]
    D -- update --> E[loss history]
    D --> F[compute probability]
    F --> G(form batch)
    G -- backward pass --> C
    E --> F
```

In the figure at the top, the second column represents this algorithm, and the third column additionally modifies the forward pass: the losses are refreshed only once every \(n\) epochs, and during the intervening \(n-1\) epochs the stale losses are reused to select samples (the selection must happen before the samples are fed into the network, otherwise no time would be saved).
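
A rough sketch of that stale variant, reusing select_prob (and random) from the snippet above; refresh_every and cached_loss are my own illustrative names, not the paper's:

```python
refresh_every = 3   # recompute losses only every `refresh_every` epochs
cached_loss = {}    # sample index -> loss from the last refresh epoch

def stale_filter(indices, epoch):
    """Filter sample indices *before* the forward pass using cached losses."""
    if epoch % refresh_every == 0:
        return indices  # refresh epoch: keep everything, losses get recomputed
    return [i for i in indices  # unseen samples default to +inf, i.e. always kept
            if random.random() < select_prob(cached_loss.get(i, float("inf")))]
```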

Among randomized algorithms there is a single-pass method for selecting a sample from a stream, but it only selects one element, so it is of no help when many samples must be selected; picking many at once in a single pass seems hard to arrange.
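
For reference, the single-pass method in question is presumably reservoir sampling; in its basic form it keeps exactly one uniformly chosen element of a stream:

```python
import random

def reservoir_one(stream):
    """Single-pass uniform selection of one element (reservoir sampling, k=1)."""
    chosen = None
    for n, item in enumerate(stream, start=1):
        if random.random() < 1.0 / n:  # replace the current pick w.p. 1/n
            chosen = item
    return chosen

print(reservoir_one(range(100)))  # each of 0..99 is chosen w.p. 1/100
```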

Code

Owing to hardware constraints, this code has not been tested; the paper also ships with excellent official code.

"""
OptInput.py
纯粹是为了便于交互一些, 直接用argparse也可以
"""


class Unit:

    def __init__(self, command, type=str,
                    default=None):
        if default is None:
            default = type()
        self.command = command
        self.type = type
        self.default = default

class Opi:
    """
    >>> parser = Opi()
    >>> parser.add_opt(command="lr", type=float)
    >>> parser.add_opt(command="epochs", type=int)
    """
    def __init__(self):
        self.store = []
        self.infos = {}

    def add_opt(self, **kwargs):
        self.store.append(
            Unit(**kwargs)
        )

    def acquire(self):
        s = "Acquire args {0.command} [" \
            "type:{0.type.__name__} " \
            "default:{0.default}] : "
        for unit in self.store:
            while True:
                inp = input(s.format(
                    unit
                ))
                try:
                    if inp:  # something was entered
                        inp = unit.type(inp)
                    else:
                        inp = unit.default
                    self.infos.update(
                        {unit.command: inp}
                    )
                    setattr(self, unit.command, inp)
                    break
                except (TypeError, ValueError):
                    print("Type {0} should be given".format(
                        unit.type.__name__
                    ))


if __name__ == "__main__":
    parser = Opi()
    parser.add_opt(command = "x", type=int)
    parser.add_opt(command="y", type=str)
    parser.acquire()
    print(parser.infos)
    print(parser.x)
'''
calcprob.py
Compute the selection probabilities.
'''




import collections



class Calcprob:
    def __init__(self, beta, sample_min, max_len=3000):
        assert 0. <= sample_min <= 1., "Invalid sample_min"
        assert beta > 0, "Invalid beta"
        self.beta = beta
        self.sample_min = sample_min
        self.max_len = max_len
        self.history = collections.deque(maxlen=max_len)
        self.num_slot = 1000             # histogram resolution
        self.hist = [0] * self.num_slot  # per-slot counts of stored losses
        self.count = 0                   # number of losses currently stored

    def update_history(self, losses):
        """
        BoundedHistogram: maintain a histogram over the most recent
        `max_len` losses. Losses are assumed to be scaled into [0, 1);
        anything larger wraps around the histogram slots.
        :param losses:
        :return:
        """
        for loss in losses:
            loss = float(loss)  # store plain floats, never tensors with graphs
            assert loss > 0
            if self.count == self.max_len:  # full: evict the oldest loss first
                loss_old = self.history.popleft()
                slot_old = int(loss_old * self.num_slot) % self.num_slot
                self.hist[slot_old] -= 1
            else:
                self.count += 1
            self.history.append(loss)  # the new loss always enters the history
            slot = int(loss * self.num_slot) % self.num_slot
            self.hist[slot] += 1
    def get_probability(self, loss):
        loss = float(loss)
        assert loss > 0
        slot = int(loss * self.num_slot) % self.num_slot
        # fraction of stored losses falling into strictly smaller slots: k / n
        prob = sum(self.hist[:slot]) / self.count
        return prob ** self.beta

    def calc_probability(self, losses):
        if isinstance(losses, float):
            losses =  (losses, )
        self.update_history(losses)
        probs = (
            max(
                self.get_probability(loss),
                self.sample_min
            )
            for loss in losses
        )
        return probs

    def __call__(self, losses):
        return self.calc_probability(losses)


if __name__ == "__main__":
    pass
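    # Quick smoke test (illustrative values; losses are assumed to lie in
    # (0, 1) so the histogram slots do not wrap):
    import random
    calc = Calcprob(beta=3., sample_min=0.1, max_len=100)
    toy_losses = [random.random() + 1e-6 for _ in range(10)]
    for p in calc(toy_losses):
        print(round(p, 3))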


'''
selector.py
'''


import calcprob
import numpy as np


class Selector:

    def __init__(self, batch_size,
                 beta, sample_min, max_len=3000):
        self.batch_size = batch_size
        self.calcprob = calcprob.Calcprob(beta,
                                          sample_min,
                                          max_len)
        self.reset()

    def backward(self):
        loss = sum(self.batch)
        loss.backward()
        self.reset()

    def reset(self):
        self.batch = []
        self.length = 0

    def select(self, losses):
        """
        Accumulate the selected losses; once `batch_size` of them have been
        collected, run backward(). Returns True if a backward pass happened,
        so the caller knows when to step the optimizer.
        """
        stepped = False
        probs = self.calcprob(losses)
        for i, prob in enumerate(probs):
            if np.random.rand() < prob:
                self.batch.append(losses[i])
                self.length += 1
                if self.length >= self.batch_size:
                    self.backward()
                    stepped = True
        return stepped

    def __call__(self, losses):
        return self.select(losses)
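

if __name__ == "__main__":
    # Usage sketch (a toy scalar "model"): per-sample loss tensors go in,
    # and backward() fires automatically once `batch_size` of them have
    # been selected; the caller would then step its optimizer.
    import torch
    w = torch.tensor(1., requires_grad=True)
    sel = Selector(batch_size=4, beta=3., sample_min=0.1)
    toy_losses = [(w * x - 1.) ** 2 for x in torch.rand(16)]
    stepped = sel(toy_losses)  # True if at least one bp batch was formed
    print(stepped, w.grad)     # w.grad is populated only after a backward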


'''
main.py
'''

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os


import selector





class Train:

    def __init__(self, model, lossfunc,
                 bpsize, beta, sample_min, max_len=3000,
                 lr=0.01, momentum=0.9, weight_decay=0.0001):
        self.net = self.choose_net(model)
        self.criterion = self.choose_lossfunc(lossfunc)
        self.opti = torch.optim.SGD(self.net.parameters(),
                                    lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)
        self.selector = selector.Selector(bpsize, beta,
                                          sample_min, max_len)
        self.gpu()
        self.generate_path()
        self.acc_rates = []
        self.errors = []

    def choose_net(self, model):
        net = getattr(
            torchvision.models,
            model,
            None
        )
        if net is None:
            raise ValueError("no such model")
        return net()

    def choose_lossfunc(self, lossfunc):
        lossfunc = getattr(
            nn,
            lossfunc,
            None
        )
        if lossfunc is None:
            raise ValueError("no such lossfunc")
        return lossfunc()  # instantiate the loss module



    def gpu(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() > 1:
            print("Let'us use %d GPUs" % torch.cuda.device_count())
            self.net = nn.DataParallel(self.net)
        self.net = self.net.to(self.device)



    def generate_path(self):
        """
        生成保存数据的路径
        :return:
        """
        try:
            os.makedirs('./paras')
            os.makedirs('./logs')
            os.makedirs('./infos')
        except FileExistsError as e:
            pass
        name = self.net.__class__.__name__
        paras = os.listdir('./paras')
        logs = os.listdir('./logs')
        infos = os.listdir('./infos')
        number = max((len(paras), len(logs), len(infos)))
        self.para_path = "./paras/{0}{1}.pt".format(
            name,
            number
        )

        self.log_path = "./logs/{0}{1}.txt".format(
            name,
            number
        )
        self.info_path = "./infos/{0}{1}.npy".format(
            name,
            number
        )


    def log(self, strings):
        """
        运行日志
        :param strings:
        :return:
        """
        # a 往后添加内容
        with open(self.log_path, 'a', encoding='utf8') as f:
            f.write(strings)

    def save(self):
        """
        保存网络参数
        :return:
        """
        torch.save(self.net.state_dict(), self.para_path)

    def decrease_lr(self, multi=0.96):
        """
        Decay the learning rate.
        :param multi:
        :return:
        """
        self.opti.param_groups[0]['lr'] *= multi


    def train(self, trainloader, epochs=50):
        data_size = len(trainloader) * trainloader.batch_size
        part = int(trainloader.batch_size / 2)
        for epoch in range(epochs):
            running_loss = 0.
            total_loss = 0.
            acc_count = 0.
            if (epoch + 1) % 8 == 0:
                self.decrease_lr()
                self.log(  # logging
                    "learning rate change!!!\n"
                )
            for i, data in enumerate(trainloader):
                imgs, labels = data
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                out = self.net(imgs)
                _, pre = torch.max(out, 1)  # predicted classes
                acc_count += (pre == labels).sum().item()  # count correct predictions

                losses = [  # a list, not a generator: the selector indexes into it
                    self.criterion(out[j:j+1], labels[j:j+1])
                    for j in range(len(labels))
                ]

                if self.selector(losses):  # backward() ran on a full bp batch
                    self.opti.step()
                    self.opti.zero_grad()

                running_loss += sum(l.item() for l in losses)

                if (i + 1) % part == 0:
                    strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
                        epoch, i, running_loss / part
                    )
                    self.log(strings)  # logging
                    total_loss += running_loss
                    running_loss = 0.
            self.acc_rates.append(acc_count / data_size)
            self.errors.append(total_loss / data_size)
            self.log(  # logging
                "Accuracy of the network on %d train images: %d %%\n" %(
                    data_size, acc_count / data_size * 100
                )
            )
            self.save()  # save the network parameters
        # save some information for plotting later
        np.save(self.info_path, {
            'acc_rates': np.array(self.acc_rates),
            'errors': np.array(self.errors)
        })




if __name__ == "__main__":

    import OptInput
    args = OptInput.Opi()
    args.add_opt(command="model", default="resnet34")
    args.add_opt(command="lossfunc", default="CrossEntropyLoss")
    args.add_opt(command="bpsize", default=32)
    args.add_opt(command="beta", default=0.9)
    args.add_opt(command="sample_min", default=0.3)
    args.add_opt(command="max_len", default=3000)
    args.add_opt(command="lr", default=0.001)
    args.add_opt(command="momentum", default=0.9)
    args.add_opt(command="weight_decay", default=0.0001)

    args.acquire()

    root = "C:/Users/pkavs/1jupiterdata/data"

    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                          download=False,
                                          transform=transforms.Compose(
                                              [transforms.Resize(224),
                                               transforms.ToTensor(),
                                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
                                          ))

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=8,
                                               pin_memory=True)



    dog = Train(**args.infos)
    dog.train(train_loader, epochs=1000)





posted @ 2020-02-16 21:24  馒头and花卷