pytorch中调用C进行扩展
pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;
第一步:编写头文件
/* src/my_lib.h */ int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output); int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
第二步:编写源文件
/* src/my_lib.c */ #include <TH/TH.h> int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output) { if (!THFloatTensor_isSameSizeAs(input1, input2)) return 0; THFloatTensor_resizeAs(output, input1); THFloatTensor_cadd(output, input1, 1.0, input2); return 1; } int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input) { THFloatTensor_resizeAs(grad_input, grad_output); THFloatTensor_fill(grad_input, 1); return 1; }
注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;
第三步:在同级目录下创建一个.py文件(比如叫“build.py”)
该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);
# build.py from torch.utils.ffi import create_extension ffi = create_extension( name='_ext.my_lib', # 输出文件地址及名称 headers='src/my_lib.h', # 编译.h文件地址及名称 sources=['src/my_lib.c'], # 编译.c文件地址及名称 with_cuda=False # 不使用cuda ) ffi.build()
第四步:编写.py脚本调用编译好的C扩展模块
import torch from torch.autograd import Function from _ext import my_lib import torch.nn as nn class MyAddFunction(Function): def forward(self, input1, input2): output = torch.FloatTensor() my_lib.my_lib_add_forward(input1, input2, output) return output def backward(self, grad_output): grad_input = torch.FloatTensor() my_lib.my_lib_add_backward(grad_input, grad_output) return grad_input class MyAddModule(nn.Module): def forward(self, input1, input2): return MyAddFunction()(input1, input2) class MyNetWork(nn.Module): def __init__(self): super(MyNetWork, self).__init__() self.add = MyAddModule() def forward(self, input1, input2): return self.add(input1, input2) model = MyNetWork() input1, input2 = torch.randn(5, 5), torch.randn(5, 5) print(model(input1, input2)) print(input1 + input2)
至此,用这个简单的例子抛砖引玉~