声明的类模板中成员函数、成员模板函数、友元函数、友元模板函数的定义方式
点击查看代码
template<typename T>
class TensorBase {
public:
void init(std::string name = NULL, int num = 0, int nbDims = 1, std::vector<int> dimA = {1});
static TensorBase<T>* getInstance() {
if(instance == NULL)
instance = new TensorBase<T>();
return instance;
}
T* dev_malloc(int nums);
T* host_malloc(int nums);
public:
int size() const;
template<typename T2, typename...Args>
friend void print_display(const T2 &val, const Args&...rest);
template<typename Tc>
friend int compare(const Tc &val1, const Tc & val2);
void print_stride() {
for(auto x : strideA_m) {
std::cout<<x<<' ';
}
}
private:
std::string name_m;
int num_m;
int nbDims_m;
std::vector<int> dimA_m;
std::vector<int> strideA_m;
T* dev_ptr;
T* host_ptr;
static TensorBase<T> *instance;
};
template<typename T>
TensorBase<T>* TensorBase<T>::instance = NULL;
template<typename T>
void TensorBase<T>::init(std::string name, int num, int nbDims, std::vector<int> dimA) {
int nums = 1;
for(auto x : dimA) {
nums *= x;
}
if(dimA.size() != nbDims || nums != num) {
std::cout << "input param error!" << std::endl;
return;
}
name_m = name;
num_m = num;
nbDims_m = nbDims;
strideA_m.clear();
dimA_m.assign(dimA.begin(), dimA.end());
for(int i = 0; i < nbDims; i++) {
if(dimA.size() == 1) {
strideA_m.push_back(1);
break;
}
int stride_cal = 1, j = 0;
for(j = i + 1; j < nbDims; j++) {
stride_cal *= dimA[j];
}
strideA_m.push_back(stride_cal);
}
}
template<typename T>
int TensorBase<T>::size() const{
return num_m;
}
template<typename T>
T* TensorBase<T>::dev_malloc(int nums) {
dev_ptr = (T*) malloc(nums * sizeof(T));
memset(dev_ptr, 0, nums);
return dev_ptr;
}
template<typename T>
T* TensorBase<T>::host_malloc(int nums) {
host_ptr = (T*) malloc(nums * sizeof(T));
memset(dev_ptr, 0, nums);
return host_ptr;
}
inline void print_display() {}
template<typename T2, typename...Args>
void print_display(const T2 &val, const Args&...rest) {
std::cout << val << ' ';
print_display(rest...);
}
template< typename...Args>
void print_display(const std::vector<int> &val, const Args&...rest) {
for(auto x : val) { std::cout << x <<' '; }
print_display(rest...);
}
template< typename...Args>
void print_display(const std::vector<float> &val, const Args&...rest) {
for(auto x : val) { std::cout << x <<' '; }
print_display(rest...);
}
template<typename To>
std::ostream& operator<< (std::ostream& out , std::vector<To> &arr) {
for(auto x : arr) { out<< x << ' '; }
return out;
}
template<typename Tc>
int compare(const Tc &val1, const Tc & val2) {
return val1 == val2;
}
注: **友元函数模板的模板类型需与类模板不同**
---
此代码中包含可变参数模板的使用方法,因可变参数是递归调用的,因此需要在定义最后一层递归的函数实现【==最后一次函数实现最好定义为内联,否则可能会出现重定义问题==】
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 张高兴的大模型开发实战:(一)使用 Selenium 进行网页爬虫
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构