numba-jitclass

参考文档:https://numba.pydata.org/numba-doc/latest/user/jitclass.html#

jitclass

对类进行装饰

import numba as nb
import numpy as np
from numba.experimental import jitclass

spec = [
    ("value", nb.int32),
    ("array", nb.float32[:]),
]


@jitclass(spec)
class Bag:
    def __init__(self, value):
        # self.value = value
        self.array = np.zeros(value, dtype=np.float32)

    @property
    def size(self):
        return self.array.size

    def increment(self, val):
        for i in range(self.size):
            self.array[i] += val

        return self.array


a = 2
b = Bag(a)
print("b.increment(1): ", b.increment(1))  # b.increment(1):  [1. 1.]

上面例子中,spec 提供了一个两元组元素的数组,元组包含字段名称和类型。也可以使用有序字典映射字段与类选的关系。
类中至少要初始化每个定义的字段,如果不初始化,字段会包含垃圾数据。

具体的 numb.typed 容器(container)做类成员

  1. 显式的类型和构建
kv_ty = (nb.types.int64, nb.types.unicode_type)


@jitclass(
    [("d", nb.types.DictType(*kv_ty)), ("l", nb.types.ListType(nb.types.float64))]
)
class ContainerHolder:
    def __init__(self):
        self.d = nb.typed.Dict.empty(*kv_ty)
        self.l = nb.typed.List.empty_list(nb.types.float64)


c = ContainerHolder()
c.d[1] = "apple"
c.d[2] = "orange"
c.l.append(1.0)
c.l.append(2.0)

print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]
  1. 另一个有用的模式是用 numba.typed 的 _numba_type_ 属性能够找找到容器的类型, 这样可以直接在 python 解释器中访问容器的实例。使用 numba.typeof 可以得到跟容器实例一样的信息。如下:
d = nb.typed.Dict()
d[1] = "apple"
d[2] = "orange"

l = nb.typed.List()
l.append(1.0)
l.append(2.0)


@jitclass([("d", nb.typeof(d)), ("l", nb.typeof(l))])
class ContainerInsHolder:
    def __init__(self, dict_instance, list_instance):
        self.d = dict_instance
        self.l = list_instance


c = ContainerInsHolder(d, l)
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]
c.d[3] = "banana"
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange, 3: banana}

需要注意的是,容器实例在使用前必须要初始化,否则会有异常,如下面的是错误的:

d_ty = nb.types.DictType(nb.types.int64, nb.types.unicode_type)


@jitclass([("d", d_ty)])
class NotInitContainer:
    def __init__(self):
        self.d[10] = "apple"  # d 没有被初始化,这里是无效的


NotInitContainer()  # 实例化会失败,内存访问无效,程序会异常结束 Process finished with exit code -1073741819 (0xC0000005)

支持的操作

以下 jitclasses 操作在 python 解释器和 numba 编译的函数中都支持:

  • 用 jitclass 类实例化对象。(如: bag = Bag(123))
  • 读/写属性。(如:bag.value)
  • 方法调用。(如:bag.increment(2))
  • 调用实例的静态方法。(如:bag.add(1, 2))
  • 调用类的静态方法。(如:Bag.add(1,2))

局限性

  • jitclass 被看作是一个 numba 的编译函数
  • isinstance() 只能在 python 解释器中使用
完整代码
import numba as nb
import numpy as np
from numba.experimental import jitclass

spec = [
    ("value", nb.int32),
    ("array", nb.float32[:]),
]


@jitclass(spec)
class Bag:
    def __init__(self, value):
        # self.value = value
        self.array = np.zeros(value, dtype=np.float32)

    @property
    def size(self):
        return self.array.size

    def increment(self, val):
        for i in range(self.size):
            self.array[i] += val

        return self.array


a = 2
b = Bag(a)
print("b.increment(1): ", b.increment(1))  # b.increment(1):  [1. 1.]

print("-" * 20)
kv_ty = (nb.types.int64, nb.types.unicode_type)


@jitclass(
    [("d", nb.types.DictType(*kv_ty)), ("l", nb.types.ListType(nb.types.float64))]
)
class ContainerHolder:
    def __init__(self):
        self.d = nb.typed.Dict.empty(*kv_ty)
        self.l = nb.typed.List.empty_list(nb.types.float64)


c = ContainerHolder()
c.d[1] = "apple"
c.d[2] = "orange"
c.l.append(1.0)
c.l.append(2.0)

print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]


print("-" * 20)

d = nb.typed.Dict()
d[1] = "apple"
d[2] = "orange"

l = nb.typed.List()
l.append(1.0)
l.append(2.0)


@jitclass([("d", nb.typeof(d)), ("l", nb.typeof(l))])
class ContainerInsHolder:
    def __init__(self, dict_instance, list_instance):
        self.d = dict_instance
        self.l = list_instance


c = ContainerInsHolder(d, l)
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]
c.d[3] = "banana"
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange, 3: banana}


print("-" * 20)
d_ty = nb.types.DictType(nb.types.int64, nb.types.unicode_type)


@jitclass([("d", d_ty)])
class NotInitContainer:
    def __init__(self):
        self.d[10] = "apple"  # d 没有被初始化,这里是无效的


NotInitContainer()  # 实例化会失败,内存访问无效,程序会异常结束 Process finished with exit code -1073741819 (0xC0000005)

posted @ 2024-05-12 16:51  一枚码农  阅读(17)  评论(0编辑  收藏  举报