1 #ifndef SMART_PRT_INCLUDE_HEADER 2 #define SMART_PRT_INCLUDE_HEADER 3 #include <cstddef> 4 5 template<class T> 6 class smart_ptr{ 7 public: 8 smart_ptr(T *); 9 smart_ptr(smart_ptr&); 10 ~smart_ptr(); 11 smart_ptr<T>& operator=(smart_ptr&); 12 T* operator->(); 13 T &operator*(); 14 private: 15 T *m_ptr; 16 size_t *m_use; 17 void destroy(); 18 }; 19 20 #endif //SMART_PRT_INCLUDE_HEADER
根据指针的一些定义,我们不难定义上述头文件,最后要的是实现上述类模板的复制控制.为了方面使用,需要对 -> 和 * 操作符重载以保留指针原有的使用方法.把智能指针定义为一个类模板的好处是能够清楚的分辨指针的类型.一般情况下不需要对模板进行特化.但如果发布为库的话则无法实现模板的功能,因此一般会考虑对每个项目直接使用所有类模板的代码.
最后,可以做如下定义:
1 #include "smart_ptr.h" 2 #include <iostream> 3 4 using namespace std; 5 6 template<class T> 7 smart_ptr<T>::smart_ptr(T *ptr) 8 :m_ptr(ptr),m_use(new int(1)){ 9 cout<<"构造函数:"<<endl; 10 cout << "当前 ptr 所在存储区引用数:" << *m_use << endl; 11 } 12 13 template<class T> 14 smart_ptr<T>::smart_ptr(smart_ptr ©):m_ptr(copy.m_ptr),m_use(&(++(*copy.m_use))){ 15 cout<<"复制构造函数:"<<endl; 16 cout << "当前 ptr 所在存储区引用数:" << *m_use << endl; 17 } 18 19 template<class T> 20 smart_ptr<T>& smart_ptr<T>::operator=(smart_ptr &rhs){ 21 ++ rhs.m_use; //此处需要防止自我复制 22 --(*m_use); 23 m_ptr = rhs.m_ptr; 24 m_use = rhs.m_use; 25 cout <<"赋值操作符:"<<endl; 26 cout << "当前 ptr 所在存储区引用数:" << *m_use << endl; 27 return *this; 28 } 29 30 template<class T> 31 void smart_ptr<T>::destroy(){ 32 if((--(*m_use))==0){ 33 delete m_ptr; 34 delete m_use; 35 cout <<"析构函数:"<<endl; 36 cout << "当前 ptr 所在存储区引用数:0"<< endl; 37 }else{ 38 cout <<"析构函数:"<<endl; 39 cout << "当前 ptr 所在存储区引用数:" << *m_use << endl; 40 } 41 } 42 43 template<class T> 44 smart_ptr<T>::~smart_ptr(){ 45 destroy(); 46 } 47 48 template<class T> 49 T* smart_ptr<T>::operator->(){ 50 return m_ptr; 51 } 52 53 template<class T> 54 T& smart_ptr<T>::operator*(){ 55 return (*m_ptr); 56 }
1 #include <iostream> 2 #include <string> 3 #include "smart_ptr.h" 4 5 using namespace std; 6 7 int main(int argc,char *argv[]){ 8 string *ptr = new string("这是一个字符串哇"); 9 { 10 smart_ptr<string> sp(ptr); 11 smart_ptr<string> sp1 = sp; 12 smart_ptr<string> sp2 = sp; 13 sp2 = sp1; 14 } 15 std::cin.get(); 16 return 0; 17 }
上述的测试程序将输出如下信息,达到了我们要的效果!