ruijiege

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::

pybind11.hpp

pybind11.cpp

#include "pybind11.hpp"
#include <stdio.h>
#include <iostream>

using namespace std;
namespace py = pybind11;

class Animal{
public:
    Animal(const string& name){
        this->name_ = name;
    }

    virtual string go(int repeats) = 0;

protected:
    string name_;
};

class Dog : public Animal{
public:
    Dog(const string& name): Animal(name){
    }

    virtual ~Dog(){
        printf("析构了狗狗.~\n");
    }

    virtual string go(int repeats) override{
        string result;
        for(int i = 0; i < repeats; ++i){
            result += "wang! ";
        }
        return this->name_ + " ::: " + result;
    }
};

string call_animal_go(Animal* animal, int repeats){
    printf("Execute call_animal_go\n");
    return animal->go(repeats);
}

int add(int a, int b){
    printf("add a = %d, b = %d\n", a, b);
    return a + b;
}

// 对numpy的操作
void print_ndarray(const py::array& arra){

    printf("ndim = %d\n", arra.ndim());
    printf("size = %d\n", arra.size());

    for(int i = 0; i < arra.ndim(); ++i)
        printf("arra.shape[%d] = %d\n", i, arra.shape(i));
    
    float* ptr = (float*)arra.data<float>(0);
    for(int i = 0; i < arra.size(); ++i){
        printf("%f ", ptr[i]);
    }
    printf("\n");
}

//PYBIND11_MODULE(name, variable){
//}
PYBIND11_MODULE(sb, m) {

    m.doc() = R"doc(
        这里是介绍 ~~~
    )doc";

    m.attr("name") = "小王";
    m.def("add", &add, "加法函数,实现两个数的加法", py::arg("a"), py::arg("b")=0);

    // 声明基类,基类是抽象类,没有构造函数
    py::class_<Animal>(m, "Animal")
        .def("go", &Animal::go);

    // 声明子类,但是子类有构造函数,并有一个参数
    py::class_<Dog, Animal>(m, "Dog")
        .def(py::init<const string&>());

    // torch._C
    m.def("call_animal_go", &call_animal_go, "调用Animal的go方法");
    m.def("print_ndarray", &print_ndarray, "打印ndarray里边的信息", py::arg("arra"));
}
View Code

pybind11.test

import sb
import numpy as np

print(sb.name)
print(sb.__doc__)

print(sb.add.__doc__)
print(sb.add(5))

dog = sb.Dog("小狗")
print(dog.go(3))
print(sb.call_animal_go(dog, 3))

# shared_ptr   引用计数,  循环引用(导致两个或多个对象都无法达成计数为0,无法卸载),使用弱引用解决
# 引用计数为0时,释放对象
# python的所有对象,都是引用计数的技术实现的
# weak_ptr     弱引用     不占用计数的引用 , 需要用的时候提申请(如果对象没了,就返回false)
a = sb.Dog("别的东西")
b = a
dog = None
print("Step~~~~~~~")

arra = np.arange(25).reshape(5, 5).astype(np.float32)
sb.print_ndarray(arra)
View Code

 

posted on 2022-11-30 10:03  哦哟这个怎么搞  阅读(50)  评论(0编辑  收藏  举报