由一维数组表示的N维数组实现(C++)
工作中经常需要表示多维数组(如二维矩阵),常见的做法是使用二级指针 T **pArr 逐行分配内存:
T **pArr = new T*[M];//创建二维数组[M][N] for (int i=0;i<M;i++) { pArr[i] = new T[N]; }
销毁内存:
for (int i=0;i<M;i++) { delete[] pArr[i]; } delete[] pArr;
若是三维数组,则需要使用 T ***pArr 并逐层分配三次,以此类推;维数越高,操作越繁琐,也越容易出错。
为方便动态生成多维数组,本文基于 C++ 模板和运算符重载,使用一维数组动态表示多维数组(目前最多支持 4 维),并支持数据切片,从而简化多维数组的处理和使用。
#ifndef MYNDARRAY_H_20170226
#define MYNDARRAY_H_20170226

#include <assert.h>
#include <iostream>

// NOTE(review): a using-directive in a header leaks into every includer. It is
// kept because existing callers (e.g. the usage example) use unqualified
// cout / ifstream / ios.
using namespace std;

// A one-dimensional buffer is used to represent an N-dimensional array.

// Unsigned integer type used for all sizes and indices below.
typedef unsigned int iNum;

// Closed interval [Start, End]; both bounds are 0-based.
typedef struct stMyRange {
    iNum Start;  // first index (0-based, inclusive)
    iNum End;    // last index (0-based, inclusive)

    // Empty range [0, 0].
    stMyRange() {
        Start = End = 0;
    }

    // Requires end >= start. iNum is unsigned, so "start >= 0" always holds
    // and is deliberately not asserted.
    stMyRange(iNum start, iNum end) {
        assert(end >= start);
        Start = start;
        End = end;
    }
} stMyRange;

// One dimension ("rank") of an array: its full length plus the [Start, End]
// window of that dimension that is currently visible.
typedef struct stMyRank {
    iNum Rank;   // full length of this dimension; 0 means the dimension does not exist
    iNum Start;  // first visible position (0-based)
    iNum End;    // last visible position (0-based)

    stMyRank() {
        Rank = 0;   // dimension not in use
        Start = 0;  // Start == 0 && End == Rank - 1 means the whole dimension is visible
        End = 0;
    }

    // Sets the dimension length (must be > 0) and makes the whole dimension visible.
    void SetRank(iNum rank) {
        assert(rank > 0);
        Rank = rank;
        Start = 0;
        End = Rank - 1;
    }

    // Restricts the visible window to rang, which must lie inside [0, Rank - 1].
    void SetRange(stMyRange rang) {
        assert(rang.End >= rang.Start && rang.End < Rank);
        Start = rang.Start;
        End = rang.End;
    }
} stMyRank;
//目前设置最大数据维度为4,后续根据需要可扩展 #define MyMaxRank 4 template<class T> class MyNDArray{//C++模板使用,可表示多种类型的数据 protected: T* m_pData;//row based storage,行优先存储 stMyRank m_Rank[MyMaxRank];//记录数据维度 iNum m_iRankCap;//total rank capacity;数据维度个数 bool m_isSlice;//if slice==true, it will reuse the m_pData of its parent;//是否是数据切片 public: MyNDArray(){ m_pData = NULL; m_iRankCap = 0; m_isSlice = false; } MyNDArray(iNum rank1){ m_Rank[0].SetRank(rank1); m_iRankCap = 1; m_isSlice = false; Init(); } MyNDArray(iNum rank1,iNum rank2){ m_Rank[0].SetRank(rank1); m_Rank[1].SetRank(rank2); m_iRankCap = 2; m_isSlice = false; Init(); } MyNDArray(iNum rank1,iNum rank2,iNum rank3){ m_Rank[0].SetRank(rank1); m_Rank[1].SetRank(rank2); m_Rank[2].SetRank(rank3); m_iRankCap = 3; m_isSlice = false; Init(); } MyNDArray(iNum rank1,iNum rank2,iNum rank3,iNum rank4){ m_Rank[0].SetRank(rank1); m_Rank[1].SetRank(rank2); m_Rank[2].SetRank(rank3); m_Rank[3].SetRank(rank4); m_iRankCap = 4; m_isSlice = false; Init(); } //创建数据切片,新的ndarray共享partent的pData数据域 MyNDArray(MyNDArray& parent,stMyRank rank1){ m_Rank[0]=rank1; m_iRankCap = 1; m_isSlice = true; m_pData = parent.GetData(); } MyNDArray(MyNDArray& parent,stMyRank rank1,stMyRank rank2){ m_Rank[0]=rank1; m_Rank[1]=rank2; m_iRankCap = 2; m_isSlice = true; m_pData = parent.GetData(); } MyNDArray(MyNDArray& parent,stMyRank rank1,stMyRank rank2,stMyRank rank3){ m_Rank[0]=rank1; m_Rank[1]=rank2; m_Rank[2]=rank3; m_iRankCap = 3; m_isSlice = true; m_pData = parent.GetData(); } MyNDArray(MyNDArray& parent,stMyRank rank1,stMyRank rank2,stMyRank rank3,stMyRank rank4){ m_Rank[0]=rank1; m_Rank[1]=rank2; m_Rank[2]=rank3; m_Rank[3]=rank4; m_iRankCap = 4; m_isSlice = true; m_pData = parent.GetData(); } ~MyNDArray(){ Clear(); } void Clear(){ if(!m_isSlice && m_pData != NULL){ delete[] m_pData; m_pData = NULL; } } T* GetData(){ return m_pData; } void SetData(T* pData){ m_pData = pData; } iNum RankCap(){ return m_iRankCap; } iNum RankSize(iNum rank){//start from 1,获取第i维的维度长度 if(rank 
< 1) return 0; return m_Rank[rank-1].End - m_Rank[rank-1].Start + 1; } MyNDArray Slice(stMyRange rang1){//range start from 0,数据切片 stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); MyNDArray res(*this,rank1); return res; } MyNDArray Slice(stMyRange rang1,stMyRange rang2){ stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); stMyRank rank2 = m_Rank[1]; rank2.SetRange(rang2); MyNDArray res(*this,rank1,rank2); return res; } MyNDArray Slice(stMyRange rang1,stMyRange rang2,stMyRange rang3){ stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); stMyRank rank2 = m_Rank[1]; rank2.SetRange(rang2); stMyRank rank3 = m_Rank[2]; rank3.SetRange(rang3); MyNDArray res(*this,rank1,rank2,rank3); return res; } MyNDArray Slice(stMyRange rang1,stMyRange rang2,stMyRange rang3,stMyRange rang4){ stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); stMyRank rank2 = m_Rank[1]; rank2.SetRange(rang2); stMyRank rank3 = m_Rank[2]; rank3.SetRange(rang3); stMyRank rank4 = m_Rank[3]; rank4.SetRange(rang4); MyNDArray res(*this,rank1,rank2,rank3,rank4); return res; } //override operator(), 使用MyNDarray arr; arr(1,2,3,4) T& operator()(iNum r1=0,iNum r2=0,iNum r3=0,iNum r4=0){ assert(m_pData != NULL); return m_pData[StoreIndex(r1,r2,r3,r4)]; } const T& operator()(iNum r1=0,iNum r2=0,iNum r3=0,iNum r4=0) const{ assert(m_pData != NULL); return m_pData[StoreIndex(r1,r2,r3,r4)]; } protected: iNum StoreIndex(iNum r1,iNum r2,iNum r3,iNum r4){//数据存储引擎,以行优先存储 assert(m_iRankCap > 0); if(m_isSlice){ r1 += m_Rank[0].Start; r2 += m_Rank[1].Start; r3 += m_Rank[2].Start; r4 += m_Rank[3].Start; } //判断数据是否在合法的范围内 if(m_iRankCap >= 1) assert(r1 >= m_Rank[0].Start && r1 <= m_Rank[0].End); if(m_iRankCap >= 2) assert(r2 >= m_Rank[1].Start && r2 <= m_Rank[1].End); if(m_iRankCap >= 3) assert(r3 >= m_Rank[2].Start && r3 <= m_Rank[2].End); if(m_iRankCap >= 4) assert(r4 >= m_Rank[3].Start && r4 <= m_Rank[3].End); iNum index = 0; index = r1 + m_Rank[0].Rank * ( r2 + m_Rank[1].Rank * ( r3 + m_Rank[2].Rank * r4 ));//row based storage 
return index; } void Init(){//初始化内存 iNum size = 1; for (int i=0;i<m_iRankCap;i++) { size *= m_Rank[i].Rank; } if(m_iRankCap > 0 && size > 0) { m_pData = new T[size]; } else m_pData = NULL; } };
//重载输入>>和输出<<,用于数据读取和写入 //override operator >> template<class T> istream& operator >> (istream& myin,MyNDArray<T>& arr){ iNum r1,r2,r3,r4; r1 = arr.RankSize(1); r2 = arr.RankSize(2); r3 = arr.RankSize(3); r4 = arr.RankSize(4); iNum rankCap = arr.RankCap(); if(rankCap < 2) r2 = 1; if(rankCap < 3) r3 = 1; if(rankCap < 4) r4 = 1; for (int i4=0;i4<r4;i4++) { for (int i3=0;i3<r3;i3++) { for (int i2=0;i2<r2;i2++) { for (int i1=0;i1<r1;i1++) { myin >> arr(i1,i2,i3,i4);//调用重载的operator(),读取数据 } } } } return myin; } //override operator << template<class T> ostream& operator << (ostream& myout,MyNDArray<T>& arr){ iNum r1,r2,r3,r4; r1 = arr.RankSize(1); r2 = arr.RankSize(2); r3 = arr.RankSize(3); r4 = arr.RankSize(4); iNum rankCap = arr.RankCap(); if(rankCap < 2) r2 = 1; if(rankCap < 3) r3 = 1; if(rankCap < 4) r4 = 1; for (int i4=0;i4<r4;i4++) { myout<<"rank4:"<<i4+1<<endl; for (int i3=0;i3<r3;i3++) { myout<<"rank3:"<<i3+1<<endl; for (int i2=0;i2<r2;i2++) { for (int i1=0;i1<r1;i1++) { myout <<arr(i1,i2,i3,i4)<<"\t"; } myout<<"\n"; } } myout<<"\n"; } return myout; } #endif
使用方法:
int main(int argc, char* argv[]) { MyNDArray<int> ndarr(3,3,2,2); ifstream fin("ndarr.txt",ios::in); fin>>ndarr; cout<<ndarr; stMyRange range1(1,2); stMyRange range2(1,2); stMyRange range3(1,1); stMyRange range4(1,1); MyNDArray<int> slice = ndarr.Slice(range1,range2,range3,range4); cout<<"slice:"<<endl<<slice<<endl; cin>>ndarr; return 0; }
ndarr.txt 的内容如下:
1 2 3 4 5 6 7 8 9 11 12 13 14 15 16 17 18 19 21 22 23 24 25 26 27 28 29 31 32 33 34 35 36 37 38 39
上述测试程序中切片(slice)部分的输出结果为:
35 36
38 39
满足数据切片要求。
参考资料:
C++运算符重载:http://www.cnblogs.com/lfsblack/archive/2012/10/01/2709476.html
Python numpy.ndarray 的简单使用
《C++ Primer》中的模板章节