在类中使用numba提高python运行速度

背景

众所周知,numba 的 jit 功能可以显著提高 python 的运行速度。
然而,numba 对于程序有一些特定的要求,比如不能用 list 之类的,而且似乎也不能直接放在类(class)里。

要想在类中加速的话,有两种方法:

  1. 官方文档给出一种方法 Compiling python classes with @jitclass
    比较麻烦,需要给出类中成员,而且类中如果有 list 之类的东西就会失效
  2. 在类的外部编写函数,类中调用类外的函数,也即本文内容
    参考 How do I use numba on a member function of a class?

解决方法

对于以下代码(没有 jit 修饰):

class A():
    def __init__(self):
        self.a = 1
        self.n = 10000000

    def calc(self):
        for _ in range(self.n):
            self.a += 1


if __name__ == '__main__':
    start = time.process_time()
    a = A()
    a.calc()
    end = time.process_time()
    print('finish all in %s' % str(end - start))

正常运行,得到运行时间

finish all in 1.546875

若想加速 calc() 函数,直接在类中加上 jit 修饰,会报 warning,并且 numba 退化

    @jit
    def calc(self):
        for _ in range(self.n):
            self.a += 1

改写如下形式:

@jit
def _calc(n, a):  # 在外部写一个函数
    for _ in range(n):
        a += 1
    return a


class A():
    def __init__(self):
        self.a = 1
        self.n = 10000000

    def calc(self):  # 类中调用外部的函数
        self.a = _calc(self.n, self.a)


if __name__ == '__main__':
    start = time.process_time()
    a = A()
    a.calc()
    end = time.process_time()
    print('finish all in %s' % str(end - start))

得到运行时间

finish all in 0.3125

提速大约 5 倍,大成功

posted @ 2021-01-13 09:13  BwShen  阅读(360)  评论(0编辑  收藏  举报