pytorch中调用C进行扩展

pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;

第一步:编写头文件

/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

 

第二步:编写源文件

/* src/my_lib.c */
#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
    if (!THFloatTensor_isSameSizeAs(input1, input2))
        return 0;
    THFloatTensor_resizeAs(output, input1);
    THFloatTensor_cadd(output, input1, 1.0, input2);
    return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
    THFloatTensor_resizeAs(grad_input, grad_output);
    THFloatTensor_fill(grad_input, 1);
    return 1;
}

 

注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;

 

 第三步:在同级目录下创建一个.py文件(比如叫“build.py”)

该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);

# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name='_ext.my_lib',        # 输出文件地址及名称
headers='src/my_lib.h',    # 编译.h文件地址及名称
sources=['src/my_lib.c'],  # 编译.c文件地址及名称
with_cuda=False            # 不使用cuda
)
ffi.build()

 

第四步:编写.py脚本调用编译好的C扩展模块

import torch
from torch.autograd import Function
from _ext import my_lib
import torch.nn as nn

class MyAddFunction(Function):
    def forward(self, input1, input2):
        output = torch.FloatTensor()
        my_lib.my_lib_add_forward(input1, input2, output)
        return output

    def backward(self, grad_output):
        grad_input = torch.FloatTensor()
        my_lib.my_lib_add_backward(grad_input, grad_output)
        return grad_input

class MyAddModule(nn.Module):
    def forward(self, input1, input2):
        return MyAddFunction()(input1, input2)

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        self.add = MyAddModule()

    def forward(self, input1, input2):
        return self.add(input1, input2)

model = MyNetWork()
input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
print(model(input1, input2))
print(input1 + input2)

 

至此,用这个简单的例子抛砖引玉~

 

posted @ 2019-11-14 15:15  outthinker  阅读(1596)  评论(0编辑  收藏  举报