成员模板函数

声明的类模板中成员函数、成员模板函数、友元函数、友元模板函数的定义方式

点击查看代码
template<typename T>
class TensorBase {
 public:
    void init(std::string name = NULL, int num = 0, int nbDims = 1, std::vector<int> dimA = {1});
    // singleton pattern
    static TensorBase<T>* getInstance() {
        if(instance == NULL)
            instance = new TensorBase<T>();
        return instance;
    }
    T* dev_malloc(int nums);
    T* host_malloc(int nums);

 public:
    int size() const;
    template<typename T2, typename...Args>
    friend void print_display(const T2 &val, const Args&...rest);
    template<typename Tc>
    friend int compare(const  Tc &val1, const Tc & val2);
    void print_stride() {
        for(auto x : strideA_m) {
            std::cout<<x<<' ';
        }
    }
 private:
    std::string name_m;
    int num_m;
    int nbDims_m;
    std::vector<int> dimA_m;
    std::vector<int> strideA_m;
    T* dev_ptr;
    T* host_ptr;
    static TensorBase<T> *instance;
};
/* data initialize */
template<typename T>
TensorBase<T>* TensorBase<T>::instance = NULL;

template<typename T>
void TensorBase<T>::init(std::string name, int num, int nbDims, std::vector<int> dimA) {
    int nums = 1;
    for(auto x : dimA) {
        nums *= x;
    }
    if(dimA.size() != nbDims || nums != num) {
        std::cout << "input param error!" << std::endl;
        return;
    }
    name_m = name;
    num_m = num;
    nbDims_m = nbDims;
    strideA_m.clear();
    dimA_m.assign(dimA.begin(), dimA.end());
    for(int i = 0; i < nbDims; i++) {
        if(dimA.size() == 1) {
            strideA_m.push_back(1);
            break;
        }
        int stride_cal = 1, j = 0;
        for(j = i + 1; j < nbDims; j++) {
            stride_cal *= dimA[j];
            
        }
        strideA_m.push_back(stride_cal);
    }
}

/* size() */
template<typename T>
int TensorBase<T>::size() const{
    return num_m;
}

/* dev_malloc() */
template<typename T>
T* TensorBase<T>::dev_malloc(int nums) { 
    dev_ptr = (T*) malloc(nums * sizeof(T));
    memset(dev_ptr, 0, nums);
    return dev_ptr;
}

/* host_malloc() */
template<typename T>
T* TensorBase<T>::host_malloc(int nums) {
    host_ptr = (T*) malloc(nums * sizeof(T));
    memset(dev_ptr, 0, nums);
    return host_ptr;
}

/* Variadic Function Template */
inline void print_display() {}
template<typename T2, typename...Args>
void print_display(const T2 &val, const Args&...rest) {
    std::cout << val << ' ';
    print_display(rest...);
}

/* vector<int> Partial specialization */
template< typename...Args> 
void print_display(const std::vector<int> &val, const Args&...rest) {
    for(auto x : val) { std::cout << x <<' '; }
    print_display(rest...);
}

/* vector<float> Partial specialization */
template< typename...Args> 
void print_display(const std::vector<float> &val, const Args&...rest) {
    for(auto x : val) { std::cout << x <<' '; }
    print_display(rest...);
}

/* operator overide */
template<typename To>
std::ostream& operator<< (std::ostream& out , std::vector<To> &arr) {
    for(auto x : arr) { out<< x << ' '; }
    return out;
}
/* compare() */
template<typename Tc>
int compare(const  Tc &val1, const Tc & val2) {
    return  val1 == val2;
}


注: **友元函数模板的模板类型需与类模板不同** --- 此代码中包含可变参数模板的使用方法,因可变参数是递归调用的,因此需要在定义最后一层递归的函数实现【==最后一次函数实现最好定义为内联,否则可能会出现重定义问题==】
posted @   xing_l  阅读(170)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 张高兴的大模型开发实战:(一)使用 Selenium 进行网页爬虫
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
点击右上角即可分享
微信分享提示