机器学习猪

--------一只帅气潇洒略带才气的猪

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

1 前言

    之前在写影像融合算法的时候,免不了要实现将多光谱影像重采样到全色大小。当时为了不影响融合算法整体开发进度,其中重采样功能用的是GDAL开源库中的Warp接口实现的。

后来发现GDAL Warp接口实现的多光谱到全色影像的重采样主要存在两个问题:1 与原有平台的已有功能不兼容,产生冲突;2 效率较低。因此,决定重新设计和开发一个这样的功能,方便后期软件系统的维护等。

 

2 图像重采样

图像处理从形式上来说主要包括两个方面:1 单像素或者邻域像素的处理,比如影像的相加或者滤波运算等;2 图像几何空间变换,如图像的重采样,配准等。

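影像重采样的几何空间变换公式如下(以常用的一次多项式,即仿射变换为例):

$$x = a_0 + a_1 X + a_2 Y,\qquad y = b_0 + b_1 X + b_2 Y$$

其中 $(X, Y)$ 为重采样后影像中的像素坐标,$(x, y)$ 为其在原始影像中对应的位置,$a_0, a_1, a_2, b_0, b_1, b_2$ 为变换系数。对于本文多光谱到全色的重采样,该变换退化为按两者分辨率之比进行缩放。常用的重采样算法主要包括以下三种:1 最近邻;2 双线性;3 三次卷积。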

2.1 最近邻采样

最近邻采样的原理概括起来,就是用距采样点位置最近的一个像素的值作为采样点的像素值。在此先补充说明一点:

通常图像空间变换有两种方法:直接法和间接法。以图像重采样为例说明如下。直接法:从原始图像的行列初始值开始,根据变换公式计算每个像素在采样后图像中的位置并赋值;但这种方法会出现原始图像的多个像素点对应到同一个采样后像素点的情况(也可能有采样后的像素得不到赋值),因此还需要增加额外的处理。间接法:从重采样后图像的行列初始值开始,计算其在原始影像中的对应位置,再按一定的插值算法计算得到采样后的值。这种方法简单直接,本文采用的就是这种方法。

最近邻采样的实现算法如下:

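首先遍历采样点,根据公式计算采样点在原始图像中的位置,假设为 $(x, y)$;然后找到距 $(x, y)$ 最近的像素点,并将其像素值赋给采样点。其公式如下:

$$I'(X, Y) = I\left(\lfloor x + 0.5 \rfloor,\ \lfloor y + 0.5 \rfloor\right)$$

其中 $I$ 为原始影像,$I'$ 为重采样后的影像。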

2.2 双线性

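双线性采样与最近邻赋值不同,它找到距 $(x, y)$ 最近的四个像素点,然后以距离作为权重,分别在 $x$ 和 $y$ 方向上进行线性插值,从而得到采样点位置的值。记 $i = \lfloor x \rfloor$,$j = \lfloor y \rfloor$,$u = x - i$,$v = y - j$,则:

$$I'(X, Y) = (1-u)(1-v)\,I(i, j) + u(1-v)\,I(i+1, j) + (1-u)v\,I(i, j+1) + uv\,I(i+1, j+1)$$

具体实现见代码部分。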

2.3 三次卷积

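三次卷积利用距 $(x, y)$ 最近的 $4 \times 4 = 16$ 个像素点进行插值,同样是分别在 $x$ 和 $y$ 方向上插值。代码中使用的卷积核(参数 $a = -1$)为:

$$S(t) = \begin{cases} 1 - 2|t|^2 + |t|^3, & |t| < 1 \\ 4 - 8|t| + 5|t|^2 - |t|^3, & 1 \le |t| < 2 \\ 0, & |t| \ge 2 \end{cases}$$

具体实现见代码部分。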

 

3 实验结果

主要实现了两个版本的插值程序:1 CPU 版本;2 GPU 版本(基于 OpenCL,作者初学)。考虑到多光谱和全色影像范围大小不一致的事实,算法首先计算全色和多光谱影像的重叠区域。话不多说,先看看详细的代码实现过程。

CPU版本:

  1 #ifndef RESAMPLECPU_H
  2 #define RESAMPLECPU_H
  3 
#include <assert.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <vector>

#include <gdal_alg_priv.h>
#include <gdal_priv.h>
  7 
  8 
  9 template<typename T>
 10 void ReSampleCPUKernel(const float *mssData,
 11                        T *resampleData,
 12                        int mssWidth,
 13                        int mssHeight,
 14                        int mssBandCount,
 15                        int mssOffsetX,
 16                        int mssOffsetY,
 17                        int panWidth,
 18                        int panHeight,
 19                        float radioX,
 20                        float radioY,
 21                        double dfDstNoDataValue,
 22                        int MethodType)
 23 {
 24     assert(mssData != NULL);
 25     float eps = 0.00001f;
 26     for(int idx = 0;idx < panHeight;idx++){
 27         for(int idy = 0;idy<panWidth;idy++){
 28             // 找到对应的MSS像素位置
 29             float curX = (float)idx * radioX;
 30             float curY = (float)idy * radioY;
 31             if(mssData[int(curX)*mssWidth*mssBandCount + int(curY)] == dfDstNoDataValue)
 32             {
 33                 int i = 0;
 34                 while(i < mssBandCount){
 35                     resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = T(dfDstNoDataValue);
 36                     i++;
 37                 }
 38                 continue;
 39             }
 40             if(MethodType == 0){  // 最近邻
 41                 int nearX = (int)(curX + 0.5)>(int)curX?(int)(curX + 1):(int)curX;
 42                 int nearY = (int)(curY + 0.5)>(int)curY?(int)(curY + 1):(int)curY;
 43                 if(nearX >= mssHeight - 1){
 44                     nearX = mssHeight - 1;
 45                 }
 46                 if(nearY >= mssWidth - 1){
 47                     nearY = mssWidth - 1;
 48                 }
 49                 if(nearX < mssHeight && nearY < mssWidth){
 50                     int i = 0;
 51                     while(i < mssBandCount){
 52                         float tmp = mssData[nearX*mssWidth*mssBandCount + i*mssWidth + nearY];
 53                         resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = T(tmp);
 54                         i++;
 55                     }
 56                 }
 57             }
 58 
 59             if(MethodType == 1){   // 双线性
 60                 float dataX = curX - (int)curX;
 61                 float dataY = curY - (int)curY;
 62                 int preX = (int)curX;
 63                 int preY = (int)curY;
 64                 int postX = (int)curX + 1;
 65                 int postY = (int)curY + 1;
 66                 if(postX >= mssHeight - 1){
 67                     postX = mssHeight - 1;
 68                 }
 69                 if(postY >= mssWidth - 1){
 70                     postY = mssWidth - 1;
 71                 }
 72 
 73                 float Wx1 = 1 - dataX;
 74                 float Wx2 = dataX;
 75                 float Wy1 = 1 - dataY;
 76                 float Wy2 = dataY;
 77                 // 双线性差值核心代码
 78                 int i = 0;
 79                 while(i < mssBandCount){
 80                     float pMssValue[4] = {0,0,0,0};
 81                     pMssValue[0] = mssData[preX*mssWidth*mssBandCount + i*mssWidth + preY];
 82                     pMssValue[1] = mssData[preX*mssWidth*mssBandCount + i*mssWidth + postY];
 83                     pMssValue[2] = mssData[postX*mssWidth*mssBandCount + i*mssWidth + preY];
 84                     pMssValue[3] = mssData[postX*mssWidth*mssBandCount + i*mssWidth + postY];
 85                     float tmp = Wy1*(Wx1*pMssValue[0] + Wx2*pMssValue[2]) + Wy2*(Wx1*pMssValue[1] + Wx2*pMssValue[3]);
 86                     resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = T(tmp);
 87                     i++;
 88                 }
 89             }
 90 
 91 
 92             if(MethodType == 2){  // 双三次卷积
 93                 float dataX = curX - (int)curX;
 94                 float dataY = curY - (int)curY;
 95                 int preX1 = (int)curX - 1;
 96                 int preX2 = (int)curX;
 97                 int postX1 = (int)curX + 1;
 98                 int postX2 = (int)curX + 2;
 99                 int preY1 = (int)curY - 1;
100                 int preY2 = (int)curY;
101                 int postY1 = (int)curY + 1;
102                 int postY2 = (int)curY + 2;
103                 if(preX1 < 0) preX1 = 0;
104                 if(preY1 < 0) preY1 = 0;
105                 if(postX1 > mssHeight - 1) postX1 = mssHeight - 1;
106                 if(postX2 > mssHeight - 1) postX2 = mssHeight - 1;
107                 if(postY1 > mssWidth - 1) postY1 = mssWidth - 1;
108                 if(postY2 > mssWidth - 1) postY2 = mssWidth - 1;
109 
110                 float Wx1 = -1.0f*dataX + 2*dataX*dataX - dataX*dataX*dataX;
111                 float Wx2 = 1 - 2*dataX*dataX + dataX*dataX*dataX;
112                 float Wx3 = dataX + dataX*dataX - dataX*dataX*dataX;
113                 float Wx4 = -1.0f*dataX*dataX + dataX*dataX*dataX;
114                 float Wy1 = -1.0f*dataY + 2*dataY*dataY - dataY*dataY*dataY;
115                 float Wy2 = 1 - 2*dataY*dataY + dataY*dataY*dataY;
116                 float Wy3 = dataY + dataY*dataY - dataY*dataY*dataY;
117                 float Wy4 = -1.0f*dataY*dataY + dataY*dataY*dataY;
118                 
119                 int i = 0;
120                 while(i < mssBandCount){
121                     float *pMssValue = new float[16];
122                     memset(pMssValue,0,sizeof(float)*16);
123                     pMssValue[0] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + preY1];
124                     pMssValue[1] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + preY2];
125                     pMssValue[2] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + postY1];
126                     pMssValue[3] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + postY2];
127 
128                     pMssValue[4] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + preY1];
129                     pMssValue[5] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + preY2];
130                     pMssValue[6] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + postY1];
131                     pMssValue[7] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + postY2];
132 
133                     pMssValue[8] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + preY1];
134                     pMssValue[9] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + preY2];
135                     pMssValue[10] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + postY1];
136                     pMssValue[11] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + postY2];
137 
138                     pMssValue[12] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + preY1];
139                     pMssValue[13] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + preY2];
140                     pMssValue[14] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + postY1];
141                     pMssValue[15] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + postY2];
142 
143                     float tmp = Wy1*(Wx1*pMssValue[0] + Wx2*pMssValue[4] + Wx3*pMssValue[8] + Wx4*pMssValue[12])+    
144                         Wy2*(Wx1*pMssValue[1] + Wx2*pMssValue[5] + Wx3*pMssValue[9] + Wx4*pMssValue[13])+
145                         Wy3*(Wx1*pMssValue[2] + Wx2*pMssValue[6] + Wx3*pMssValue[10] + Wx4*pMssValue[14])+
146                         Wy4*(Wx1*pMssValue[3] + Wx2*pMssValue[7] + Wx3*pMssValue[11] + Wx4*pMssValue[15]);
147                     resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = T(tmp);
148                     delete []pMssValue;pMssValue = NULL;
149                     i++;
150                 }    
151             }
152         }
153     }
154 }
155 
156 int ReSampleCPUApp(const char *mssfileName,
157                       const char *panfileName,
158                       const char *resamplefileName,
159                       int MethodType = 1)
160 {
161     GDALAllRegister();
162     GDALDataset *poPANDS = (GDALDataset*)GDALOpen(panfileName,GA_ReadOnly);
163     GDALDataset *poMSSDS = (GDALDataset*)GDALOpen(mssfileName,GA_ReadOnly);
164     if(!poPANDS || !poMSSDS)
165         return -1;
166 
167     //MSS info
168     int mssBandCount = poMSSDS->GetRasterCount();
169     int mssWidth = poMSSDS->GetRasterXSize();
170     int mssHeight = poMSSDS->GetRasterYSize();
171     double adfMssGeoTransform[6] = {0};
172     poMSSDS->GetGeoTransform(adfMssGeoTransform);
173     GDALDataType mssDT = poMSSDS->GetRasterBand(1)->GetRasterDataType();
174 
175     int bSrcHasNoData;
176     double dfSrcNoDataValue = 0;
177     dfSrcNoDataValue = GDALGetRasterNoDataValue(poMSSDS->GetRasterBand(1),&bSrcHasNoData);
178     if(!bSrcHasNoData) dfSrcNoDataValue = 0.0;
179     //DT = mssDT;
180 
181     int *pBandMap= new int[mssBandCount];
182     for(int i = 0;i<mssBandCount;i++){
183         pBandMap[i] = i+1;
184     }
185 
186     // PAN Info
187     int panBandCount = poPANDS->GetRasterCount();
188     int panWidth = poPANDS->GetRasterXSize();
189     int panHeidht = poPANDS->GetRasterYSize();
190     double adfPanGeoTransform[6] = {0};
191     poPANDS->GetGeoTransform(adfPanGeoTransform);
192     GDALDataType panDT = poPANDS->GetRasterBand(1)->GetRasterDataType();
193 
194     // 创建新数据集=======投影信息
195     double adfResampleGeoTransform[6] = {0};
196     adfResampleGeoTransform[1] = adfPanGeoTransform[1];
197     adfResampleGeoTransform[5] = adfPanGeoTransform[5];
198     adfResampleGeoTransform[2] = adfPanGeoTransform[2];
199     adfResampleGeoTransform[4] = adfPanGeoTransform[4];
200     if(adfMssGeoTransform[0] >= adfPanGeoTransform[0]){
201         adfResampleGeoTransform[0] = adfMssGeoTransform[0];
202     }else{
203         adfResampleGeoTransform[0] = adfPanGeoTransform[0];
204     }
205     if(adfMssGeoTransform[3] > adfPanGeoTransform[3]){
206         adfResampleGeoTransform[3] = adfPanGeoTransform[3];
207     }else{
208         adfResampleGeoTransform[3] = adfMssGeoTransform[3];
209     }
210 
211     // 创建新数据集=======影像大小
212     double panEndX = adfPanGeoTransform[0] + panWidth*adfPanGeoTransform[1] + 
213         panHeidht*adfPanGeoTransform[2];
214     double panEndY = adfPanGeoTransform[3] + panHeidht*adfPanGeoTransform[4] + 
215         panHeidht*adfPanGeoTransform[5];
216 
217     double mssEndX = adfMssGeoTransform[0] +mssWidth*adfMssGeoTransform[1] + 
218         mssHeight*adfMssGeoTransform[2];
219     double mssEndY = adfMssGeoTransform[3] + mssWidth*adfMssGeoTransform[4] + 
220         mssHeight*adfMssGeoTransform[5];
221     double resampleEndXY[2] = {0};
222     if(panEndX > mssEndX)
223         resampleEndXY[0] = mssEndX;
224     else
225         resampleEndXY[0] = panEndX;
226     if(panEndY >= mssEndY)
227         resampleEndXY[1] = panEndY;
228     else
229         resampleEndXY[1] = mssEndY;
230 
231     // 创建新数据集=======MSS AND PAN 有效长宽
232     int resampleWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfResampleGeoTransform[1] + 0.5);
233     int resampleHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfResampleGeoTransform[5] + 0.5);
234     int mssEffectiveWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
235     int mssEffectiveHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfMssGeoTransform[5] + 0.5);
236     int panEffectiveWidth = resampleWidth;
237     int panEffectiveHeight = resampleHeight;
238 
239     // 创建新数据集=======位置增益大小
240     int mssGainX = static_cast<int>((adfResampleGeoTransform[0] - adfMssGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
241     int mssGainY = static_cast<int>((adfResampleGeoTransform[3] - adfMssGeoTransform[3])/adfMssGeoTransform[5] + 0.5);
242     int panGainX = static_cast<int>((adfResampleGeoTransform[0] - adfPanGeoTransform[0])/adfPanGeoTransform[1] + 0.5);
243     int panGainY = static_cast<int>((adfResampleGeoTransform[3] - adfPanGeoTransform[3])/adfPanGeoTransform[5] + 0.5);
244 
245 
246     // 创建新数据集=======创建文件
247     GDALDriver *poOutDriver = (GDALDriver*)GDALGetDriverByName("GTIFF");
248     if(!poOutDriver){
249         return -1;
250     }
251     GDALDataset *poOutDS = poOutDriver->Create(resamplefileName,resampleWidth,
252         resampleHeight,mssBandCount,mssDT,NULL);
253     poOutDS->SetGeoTransform(adfResampleGeoTransform);
254     poOutDS->SetProjection(poPANDS->GetProjectionRef());
255 
256 
257     // 重采样核心代码============图像分块
258     int iNumRow = 256;
259     if(iNumRow > resampleHeight){
260         iNumRow = 1;
261     }
262     int loopNum = (resampleHeight + iNumRow - 1)/iNumRow;  //分块数
263     int nLineSpace,nPixSpace,nBandSpace;
264     nLineSpace = sizeof(float)*mssEffectiveWidth*mssBandCount;
265     nPixSpace = 0;
266     nBandSpace = sizeof(float)*mssEffectiveWidth;
267 
268     // 重采样采样比例
269     float radioX = float(adfPanGeoTransform[1]/adfMssGeoTransform[1]);
270     float radioY = float(adfPanGeoTransform[5]/adfMssGeoTransform[5]);
271 
272     int mssCurPosX = mssGainX;
273     int mssCurPosY = mssGainY;
274     int mssCurWidth = 0;
275     int mssCurHeight = 0;
276 
277     // 重采样核心代码============
278     for(int i = 0;i<loopNum;i++){
279         int tmpRowNum = iNumRow;
280         int startR = i*iNumRow;
281         int endR = startR + iNumRow - 1;
282         if(endR>resampleHeight -1){
283             tmpRowNum = resampleHeight - startR;
284             //endR = startR + tmpRowNum - 1;
285         }
286         //计算读取的MSS影像区域大小
287         int mssCurWidth = mssEffectiveWidth;
288         int mssCurHeight = 0;
289         if(MethodType == 0)
290             mssCurHeight = int(tmpRowNum*radioY);
291         else if(MethodType == 1)
292             mssCurHeight = int(tmpRowNum*radioY)+1;
293         else if(MethodType == 2)
294             mssCurHeight = int(tmpRowNum*radioY)+2;
295 
296         if(mssCurHeight + mssCurPosY > mssHeight - 1){
297             mssCurHeight = mssHeight - mssCurPosY;
298         }
299 
300         //创建数据
301         /*float *resampleBuf = (float *)malloc(sizeof(cl_float)*tmpRowNum*resampleWidth*trueBandCount);*/
302         float *mssBuf = (float *)malloc(sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount);
303         //memset(resampleBuf,0,sizeof(float)*tmpRowNum*resampleWidth*trueBandCount);
304         memset(mssBuf,0,sizeof(float)*mssCurHeight*mssCurWidth*mssBandCount);
305 
306         // 读取数据
307         poMSSDS->RasterIO(GF_Read,mssCurPosX,mssCurPosY,mssCurWidth,mssCurHeight,
308             mssBuf,mssCurWidth,mssCurHeight,GDT_Float32,mssBandCount,NULL,nPixSpace,
309             nLineSpace,nBandSpace);
310 
311         if(MethodType == 0)
312             mssCurPosY += mssCurHeight;
313         else if(MethodType == 1)
314             mssCurPosY += mssCurHeight - 1;        
315         else if(MethodType == 2)
316             mssCurPosY += mssCurHeight - 2;
317 
318         // 数据格式转换
319         long sz = tmpRowNum*resampleWidth*mssBandCount;
320         void *resampleBuf = NULL;
321         switch(mssDT){
322             case GDT_Byte:
323                 resampleBuf = new unsigned char[sz];
324                 ReSampleCPUKernel<unsigned char>(mssBuf,(unsigned char*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
325                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
326                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
327                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned char),
328                     resampleWidth*sizeof(unsigned char));
329                 break;
330             case GDT_UInt16:
331                 resampleBuf = new unsigned short int[sz];
332                 ReSampleCPUKernel<unsigned short int>(mssBuf,(unsigned short int*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
333                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
334                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
335                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned short int),
336                     resampleWidth*sizeof(unsigned short int));
337                 break;
338             case GDT_Int16:
339                 resampleBuf = new short int[sz];
340                 ReSampleCPUKernel<short int>(mssBuf,(short int*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
341                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
342                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
343                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(short int),
344                     resampleWidth*sizeof(short int));
345                 break;
346             case GDT_UInt32:
347                 resampleBuf = new unsigned int[sz];
348                 ReSampleCPUKernel<unsigned int>(mssBuf,(unsigned int*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
349                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
350                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
351                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned int),
352                     resampleWidth*sizeof(unsigned int));
353                 break;
354             case GDT_Int32:
355                 resampleBuf = new int[sz];
356                 ReSampleCPUKernel<int>(mssBuf,(int*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
357                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
358                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
359                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(int),
360                     resampleWidth*sizeof(int));
361                 break;
362             case GDT_Float32:
363                 resampleBuf = new float[sz];
364                 ReSampleCPUKernel<float>(mssBuf,(float*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
365                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
366                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
367                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(float),
368                     resampleWidth*sizeof(float));
369                 break;
370             case GDT_Float64:
371                 resampleBuf = new double[sz];
372                 ReSampleCPUKernel<double>(mssBuf,(double*)resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
373                     resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
374                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
375                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(double),
376                     resampleWidth*sizeof(double));
377                 break;
378         }
379         delete []mssBuf;
380         delete []resampleBuf;
381         std::cout<<i<<std::endl;
382     }
383     delete []pBandMap;pBandMap = NULL;
384     GDALClose((GDALDatasetH)poPANDS);
385     GDALClose((GDALDatasetH)poMSSDS);
386     GDALClose((GDALDatasetH)poOutDS);
387     return 0;
388 }
389 
390 #endif

GPU版本:

  1 #ifndef RESAMPLEOPENCL_H
  2 #define RESAMPLEOPENCL_H
  3 
  4 #include <CL/cl.h>
  5 #include <gdal_alg_priv.h>
  6 #include <gdal_priv.h>
  7 
  8 #pragma comment(lib,"OpenCL.lib")
  9 
 10 /*
 11 @ 功能描述
 12     读取源程序,将文本源程序读到内核中
 13 */
 14 char* LoadProgSource(const char* cFilename, const char* cPreamble, size_t* szFinalLength)
 15 {
 16     FILE* pFileStream = NULL;
 17     size_t szSourceLength;
 18 
 19     // open the OpenCL source code file
 20     pFileStream = fopen(cFilename, "rb");
 21     if(pFileStream == 0) 
 22     {     
 23         return NULL;
 24     }
 25 
 26     size_t szPreambleLength = strlen(cPreamble);
 27 
 28     // get the length of the source code
 29     fseek(pFileStream, 0, SEEK_END); 
 30     szSourceLength = ftell(pFileStream);
 31     fseek(pFileStream, 0, SEEK_SET); 
 32 
 33     // allocate a buffer for the source code string and read it in
 34     char* cSourceString = (char *)malloc(szSourceLength + szPreambleLength + 1); 
 35     memcpy(cSourceString, cPreamble, szPreambleLength);
 36     if (fread((cSourceString) + szPreambleLength, szSourceLength, 1, pFileStream) != 1)
 37     {
 38         fclose(pFileStream);
 39         free(cSourceString);
 40         return 0;
 41     }
 42 
 43     // close the file and return the total length of the combined (preamble + source) string
 44     fclose(pFileStream);
 45     if(szFinalLength != 0)
 46     {
 47         *szFinalLength = szSourceLength + szPreambleLength;
 48     }
 49     cSourceString[szSourceLength + szPreambleLength] = '\0';
 50 
 51     return cSourceString;
 52 }
 53 
 54 template<typename T>
 55 bool DataTypeTrans(const float *pSrcBuf,T *pDesBuf,long size)
 56 {
 57     if(pSrcBuf == NULL){
 58         return false;
 59     }
 60     while(size--){
 61         pDesBuf[size] = T(pSrcBuf[size]);
 62     }
 63     return true;
 64 }
 65 
 66 int ReSampleOpenCLApp(const char *mssfileName,
 67                       const char *panfileName,
 68                       const char *resamplefileName,
 69                       int MethodType = 1)
 70 {
 71     GDALAllRegister();
 72     GDALDataset *poPANDS = (GDALDataset*)GDALOpen(panfileName,GA_ReadOnly);
 73     GDALDataset *poMSSDS = (GDALDataset*)GDALOpen(mssfileName,GA_ReadOnly);
 74     if(!poPANDS || !poMSSDS)
 75         return -1;
 76 
 77     //MSS info
 78     int mssBandCount = poMSSDS->GetRasterCount();
 79     int mssWidth = poMSSDS->GetRasterXSize();
 80     int mssHeight = poMSSDS->GetRasterYSize();
 81     double adfMssGeoTransform[6] = {0};
 82     poMSSDS->GetGeoTransform(adfMssGeoTransform);
 83     GDALDataType mssDT = poMSSDS->GetRasterBand(1)->GetRasterDataType();
 84 
 85     int bSrcHasNoData;
 86     float dfSrcNoDataValue = 0;
 87     dfSrcNoDataValue = (float)GDALGetRasterNoDataValue(poMSSDS->GetRasterBand(1),&bSrcHasNoData);
 88     if(!bSrcHasNoData) dfSrcNoDataValue = 0.0;
 89 
 90 
 91     // PAN Info
 92     int panBandCount = poPANDS->GetRasterCount();
 93     int panWidth = poPANDS->GetRasterXSize();
 94     int panHeidht = poPANDS->GetRasterYSize();
 95     double adfPanGeoTransform[6] = {0};
 96     poPANDS->GetGeoTransform(adfPanGeoTransform);
 97     GDALDataType panDT = poPANDS->GetRasterBand(1)->GetRasterDataType();
 98 
 99     // 创建新数据集=======投影信息
100     double adfResampleGeoTransform[6] = {0};
101     adfResampleGeoTransform[1] = adfPanGeoTransform[1];
102     adfResampleGeoTransform[5] = adfPanGeoTransform[5];
103     adfResampleGeoTransform[2] = adfPanGeoTransform[2];
104     adfResampleGeoTransform[4] = adfPanGeoTransform[4];
105     if(adfMssGeoTransform[0] >= adfPanGeoTransform[0]){
106         adfResampleGeoTransform[0] = adfMssGeoTransform[0];
107     }else{
108         adfResampleGeoTransform[0] = adfPanGeoTransform[0];
109     }
110     if(adfMssGeoTransform[3] > adfPanGeoTransform[3]){
111         adfResampleGeoTransform[3] = adfPanGeoTransform[3];
112     }else{
113         adfResampleGeoTransform[3] = adfMssGeoTransform[3];
114     }
115 
116     // 创建新数据集=======影像大小
117     double panEndX = adfPanGeoTransform[0] + panWidth*adfPanGeoTransform[1] + 
118         panHeidht*adfPanGeoTransform[2];
119     double panEndY = adfPanGeoTransform[3] + panHeidht*adfPanGeoTransform[4] + 
120         panHeidht*adfPanGeoTransform[5];
121 
122     double mssEndX = adfMssGeoTransform[0] +mssWidth*adfMssGeoTransform[1] + 
123         mssHeight*adfMssGeoTransform[2];
124     double mssEndY = adfMssGeoTransform[3] + mssWidth*adfMssGeoTransform[4] + 
125         mssHeight*adfMssGeoTransform[5];
126     double resampleEndXY[2] = {0};
127     if(panEndX > mssEndX)
128         resampleEndXY[0] = mssEndX;
129     else
130         resampleEndXY[0] = panEndX;
131     if(panEndY >= mssEndY)
132         resampleEndXY[1] = panEndY;
133     else
134         resampleEndXY[1] = mssEndY;
135 
136     // 创建新数据集=======MSS AND PAN 有效长宽
137     int resampleWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfResampleGeoTransform[1] + 0.5);
138     int resampleHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfResampleGeoTransform[5] + 0.5);
139     int mssEffectiveWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
140     int mssEffectiveHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfMssGeoTransform[5] + 0.5);
141     int panEffectiveWidth = resampleWidth;
142     int panEffectiveHeight = resampleHeight;
143 
144     // 创建新数据集=======位置增益大小
145     int mssGainX = static_cast<int>((adfResampleGeoTransform[0] - adfMssGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
146     int mssGainY = static_cast<int>((adfResampleGeoTransform[3] - adfMssGeoTransform[3])/adfMssGeoTransform[5] + 0.5);
147     int panGainX = static_cast<int>((adfResampleGeoTransform[0] - adfPanGeoTransform[0])/adfPanGeoTransform[1] + 0.5);
148     int panGainY = static_cast<int>((adfResampleGeoTransform[3] - adfPanGeoTransform[3])/adfPanGeoTransform[5] + 0.5);
149 
150 
151     // 创建新数据集=======创建文件
152     GDALDriver *poOutDriver = (GDALDriver*)GDALGetDriverByName("GTIFF");
153     if(!poOutDriver){
154         return -1;
155     }
156     GDALDataset *poOutDS = poOutDriver->Create(resamplefileName,resampleWidth,
157         resampleHeight,mssBandCount,mssDT,NULL);
158     //GDALDataset *poOutDS = poOutDriver->Create(resamplefileName,resampleWidth,
159     //    resampleHeight,mssBandCount,GDT_Float32,NULL);
160     poOutDS->SetGeoTransform(adfResampleGeoTransform);
161     poOutDS->SetProjection(poPANDS->GetProjectionRef());
162 
163     int pBandMap[4] = {1,2,3,4};
164     // 重采样核心代码============图像分块
165     int iNumRow = 256;
166     if(iNumRow > resampleHeight){
167         iNumRow = 1;
168     }
169     int loopNum = (resampleHeight + iNumRow - 1)/iNumRow;  //分块数
170     int nLineSpace,nPixSpace,nBandSpace;
171     nLineSpace = sizeof(float)*mssEffectiveWidth*mssBandCount;
172     nPixSpace = 0;
173     nBandSpace = sizeof(float)*mssEffectiveWidth;
174 
175     // 重采样采样比例
176     float radioX = adfPanGeoTransform[1]/adfMssGeoTransform[1];
177     float radioY = adfPanGeoTransform[5]/adfMssGeoTransform[5];
178 
179     int mssCurPosX = mssGainX;
180     int mssCurPosY = mssGainY;
181     int mssCurWidth = 0;
182     int mssCurHeight = 0;
183 
184     // 重采样核心代码============
185     // OpenCL部分 =============== 1 创建平台
186     cl_uint num_platforms;
187     cl_int ret = clGetPlatformIDs(0,NULL,&num_platforms);
188     if(ret != CL_SUCCESS || num_platforms < 1){
189         printf("clGetPlatformIDs Error\n");
190         return -1;
191     }
192     cl_platform_id platform_id = NULL;
193     ret = clGetPlatformIDs(1,&platform_id,NULL);
194     if(ret != CL_SUCCESS){
195         printf("clGetPlatformIDs Error2\n");
196         return -1;
197     }
198 
199     // OpenCL部分 =============== 2 获得设备
200     cl_uint num_devices;
201     ret = clGetDeviceIDs(platform_id,CL_DEVICE_TYPE_GPU,0,NULL,
202         &num_devices);
203     if(ret != CL_SUCCESS || num_devices < 1){
204         printf("clGetDeviceIDs Error\n");
205         return -1;
206     }
207     cl_device_id device_id;
208     ret = clGetDeviceIDs(platform_id,CL_DEVICE_TYPE_GPU,1,&device_id,NULL);
209     if(ret != CL_SUCCESS){
210         printf("clGetDeviceIDs Error2\n");
211         return -1;
212     }
213 
214     // OpenCL部分 =============== 3 创建Context
215     cl_context_properties props[] = {CL_CONTEXT_PLATFORM,
216         (cl_context_properties)platform_id,0};
217     cl_context context = NULL;
218     context = clCreateContext(props,1,&device_id,NULL,NULL,&ret);
219     if(ret != CL_SUCCESS || context == NULL){
220         printf("clCreateContext Error\n");
221         return -1;
222     }
223 
224     // OpenCL部分 =============== 4 创建Command Queue
225     cl_command_queue command_queue = NULL;
226     command_queue = clCreateCommandQueue(context,device_id,0,&ret);
227     if(ret != CL_SUCCESS || command_queue == NULL){
228         printf("clCreateCommandQueue Error\n");
229         return -1;
230     }
231 
232     // OpenCL部分 =============== 6 创建编译Program
233     const char *strfile = "D:\\PIE3\\src\\Test\\TextOpecCLResample\\TextOpecCLResample\\ReSampleKernel.txt";
234     size_t lenSource = 0;
235     char *kernelSource = LoadProgSource(strfile,"",&lenSource);
236     cl_program *programs = (cl_program *)malloc(loopNum*sizeof(cl_program));
237     memset(programs,0,sizeof(cl_program)*loopNum);
238 
239     cl_kernel *kernels = (cl_kernel*)malloc(loopNum*sizeof(cl_kernel));
240     memset(kernels,0,sizeof(cl_kernel)*loopNum);
241 
242 
243     for(int i = 0;i<loopNum;i++){
244         int tmpRowNum = iNumRow;
245         int startR = i*iNumRow;
246         int endR = startR + iNumRow - 1;
247         if(endR>resampleHeight -1){
248             tmpRowNum = resampleHeight - startR;
249             //endR = startR + tmpRowNum - 1;
250         }
251         //计算读取的MSS影像区域大小
252         int mssCurWidth = mssEffectiveWidth;
253         int mssCurHeight = 0;
254         if(MethodType == 0)
255             mssCurHeight = int(tmpRowNum*radioY);
256         else if(MethodType == 1)
257             mssCurHeight = int(tmpRowNum*radioY)+1;
258         else if(MethodType == 2)
259             mssCurHeight = int(tmpRowNum*radioY)+2;
260 
261         if(mssCurHeight + mssCurPosY > mssHeight - 1){
262             mssCurHeight = mssHeight - mssCurPosY;
263         }
264 
265         //创建数据
266         float *resampleBuf = (float *)malloc(sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount);
267         float *mssBuf = (float *)malloc(sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount);
268         memset(resampleBuf,0,sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount);
269         memset(mssBuf,0,sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount);
270         
271         // 读取数据
272         poMSSDS->RasterIO(GF_Read,mssCurPosX,mssCurPosY,mssCurWidth,mssCurHeight,
273             mssBuf,mssCurWidth,mssCurHeight,GDT_Float32,mssBandCount,pBandMap,nPixSpace,
274             nLineSpace,nBandSpace);
275 
276         if(MethodType == 0)
277             mssCurPosY += mssCurHeight;
278         else if(MethodType == 1)
279             mssCurPosY += mssCurHeight - 1;        
280         else if(MethodType == 2)
281             mssCurPosY += mssCurHeight - 2;
282 
283         // OpenCL部分 =============== 5 创建Memory Object
284         cl_mem mem_mss = NULL;
285         mem_mss = clCreateBuffer(context,CL_MEM_READ_WRITE | CL_MEM_USE_HOST_PTR,
286             sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount,mssBuf,&ret);
287         if(ret != CL_SUCCESS || NULL == mem_mss){
288             printf("clCreateBuffer Error\n");
289             return -1;
290         }
291 
292         cl_mem mem_resample = NULL;
293         mem_resample = clCreateBuffer(context,CL_MEM_READ_WRITE | CL_MEM_USE_HOST_PTR,
294             sizeof(cl_float)*resampleWidth*tmpRowNum*mssBandCount,resampleBuf,&ret);
295         if(ret != CL_SUCCESS || NULL == mem_resample){
296             printf("clCreateBuffer Error\n");
297             return -1;
298         }
299 
300         // OpenCL部分 =============== 6 创建编译Program
301         //const char *strfile = "D:\\PIE3\\src\\Test\\TextOpecCLResample\\TextOpecCLResample\\ReSampleKernel.txt";
302         //size_t lenSource = 0;
303         //char *kernelSource = LoadProgSource(strfile,"",&lenSource);
304         //cl_program program = NULL;
305         programs[i] = clCreateProgramWithSource(context,1,(const char**)&kernelSource,
306             NULL,&ret);
307         if(ret != CL_SUCCESS || NULL == programs[i]){
308             printf("clCreateProgramWithSource Error\n");
309             return -1;
310         }
311         ret = clBuildProgram(programs[i],1,&device_id,NULL,NULL,NULL);
312         if(ret != CL_SUCCESS){
313             char* build_log;
314             size_t log_size;
315             //查询日志的大小
316             clGetProgramBuildInfo(programs[i], device_id, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
317             build_log = new char[log_size+1];
318             //获得编译日志信息
319             ret = clGetProgramBuildInfo(programs[i], device_id, CL_PROGRAM_BUILD_LOG, log_size, build_log, NULL);
320             build_log[log_size] = '\0';
321             printf("%s\n",build_log);
322             printf("编译失败!");
323             delete []build_log;
324             return -1;
325         }
326 
327         // OpenCL部分 =============== 7 创建Kernel
328         //cl_kernel kernel = NULL;
329         kernels[i] = clCreateKernel(programs[i],"ReSampleKernel",&ret);
330         if(ret != CL_SUCCESS || NULL == kernels[i]){
331             printf("clCreateProgramWithSource Error\n");
332             return -1;
333         }
334 
335         // OpenCL部分 =============== 8 设置Kernel参数
336         ret = clSetKernelArg(kernels[i],0,sizeof(cl_mem),&mem_mss);
337         ret |= clSetKernelArg(kernels[i],1,sizeof(cl_mem),&mem_resample);
338         ret |= clSetKernelArg(kernels[i],2,sizeof(cl_int),&mssCurWidth);
339         ret |= clSetKernelArg(kernels[i],3,sizeof(cl_int),&mssCurHeight);
340         ret |= clSetKernelArg(kernels[i],4,sizeof(cl_int),&mssBandCount);
341         ret |= clSetKernelArg(kernels[i],5,sizeof(cl_int),&mssGainX);
342         ret |= clSetKernelArg(kernels[i],6,sizeof(cl_int),&mssGainY);
343         ret |= clSetKernelArg(kernels[i],7,sizeof(cl_int),&resampleWidth);
344         ret |= clSetKernelArg(kernels[i],8,sizeof(cl_int),&tmpRowNum);
345         ret |= clSetKernelArg(kernels[i],9,sizeof(cl_float),&radioX);
346         ret |= clSetKernelArg(kernels[i],10,sizeof(cl_float),&radioY);
347         ret |= clSetKernelArg(kernels[i],11,sizeof(cl_float),&dfSrcNoDataValue);
348         ret |= clSetKernelArg(kernels[i],12,sizeof(cl_int),&MethodType);
349         if(ret != CL_SUCCESS){
350             printf("clSetKernelArg Error\n");
351             return -1;
352         }
353 
354         // OpenCL部分 =============== 9 设置Group Size
355         cl_uint work_dim = 2;
356         size_t global_work_size[] = {resampleWidth,tmpRowNum};
357         size_t *local_work_size = NULL;
358 
359         // OpenCL部分 =============== 10 执行内核
360         ret = clEnqueueNDRangeKernel(command_queue,kernels[i],work_dim,NULL,global_work_size,
361             local_work_size,0,NULL,NULL);
362         ret |= clFinish(command_queue);
363         if(ret != CL_SUCCESS){
364             printf("clEnqueueNDRangeKernel Error\n");
365             return -1;
366         }
367         
368         // OpenCL部分 =============== 11 读取结果
369         
370         resampleBuf = (float*)clEnqueueMapBuffer(command_queue,mem_resample,CL_TRUE,CL_MAP_READ | CL_MAP_WRITE,
371             0,sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount,0,NULL,NULL,&ret);
372         //ret = clEnqueueReadBuffer(command_queue,mem_resample,CL_TRUE,0,
373         //    sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount,(void*)resampleBuf,0,NULL,NULL);
374         if(ret != CL_SUCCESS){
375             printf("clEnqueueMapBuffer Error\n");
376             return -1;
377         }
378 
379         
380         // 数据格式转换
381         long sz = tmpRowNum*resampleWidth*mssBandCount;
382         void *pBuf = NULL;
383         CPLErr err;
384         switch(mssDT){
385             case GDT_Byte:
386                 pBuf = new unsigned char[sz];
387                 if(!DataTypeTrans<unsigned char>(resampleBuf,(unsigned char*)pBuf,sz))
388                 {
389                     printf("DataTypeTrans Error\n");
390                     return -1;
391                 }
392                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
393                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned char),
394                     resampleWidth*sizeof(unsigned char));
395                 break;
396             case GDT_UInt16:
397                 pBuf = new unsigned short int[sz];
398                 if(!DataTypeTrans<unsigned short int>(resampleBuf,(unsigned short int*)pBuf,sz))
399                 {
400                     printf("DataTypeTrans Error\n");
401                     return -1;
402                 }
403                 err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
404                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned short int),
405                     resampleWidth*sizeof(unsigned short int));
406                 break;
407             case GDT_Int16:
408                 pBuf = new short int[sz];
409                 if(!DataTypeTrans<short int>(resampleBuf,(short int*)pBuf,sz))
410                 {
411                     printf("DataTypeTrans Error\n");
412                     return -1;
413                 }
414                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
415                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(short int),
416                     resampleWidth*sizeof(short int));
417                 break;
418             case GDT_UInt32:
419                 pBuf = new unsigned int[sz];
420                 if(!DataTypeTrans<unsigned int>(resampleBuf,(unsigned int*)pBuf,sz))
421                 {
422                     printf("DataTypeTrans Error\n");
423                     return -1;
424                 }
425                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
426                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned int),
427                     resampleWidth*sizeof(unsigned int));
428                 break;
429             case GDT_Int32:
430                 pBuf = new int[sz];
431                 if(!DataTypeTrans<int>(resampleBuf,(int*)pBuf,sz))
432                 {
433                     printf("DataTypeTrans Error\n");
434                     return -1;
435                 }
436                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
437                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(int),
438                     resampleWidth*sizeof(int));
439                 break;
440             case GDT_Float32:
441                 pBuf = new float[sz];
442                 if(!DataTypeTrans<float>(resampleBuf,(float *)pBuf,sz))
443                 {
444                     printf("DataTypeTrans Error\n");
445                     return -1;
446                 }
447                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
448                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(float),
449                     resampleWidth*sizeof(float));
450                 break;
451             case GDT_Float64:
452                 pBuf = new double[sz];
453                 if(!DataTypeTrans<double>(resampleBuf,(double *)pBuf,sz))
454                 {
455                     printf("DataTypeTrans Error\n");
456                     return -1;
457                 }
458                 poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
459                     resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(double),
460                     resampleWidth*sizeof(double));
461                 break;
462         }
463         delete []pBuf;pBuf = NULL;
464         free(mssBuf);
465         free(resampleBuf);
466 
467         // OpenCL部分 =============== 12 释放资源
468         if(NULL != mem_mss) clReleaseMemObject(mem_mss);
469         if(NULL != mem_resample) clReleaseMemObject(mem_resample);
470         std::cout<<i<<std::endl;
471     }
472     // OpenCL部分 =============== 12 释放资源
473     int i = 0;
474     while(i < loopNum){
475         if(NULL != kernels[i]) clReleaseKernel(kernels[i]);
476         if(NULL != programs[i]) clReleaseProgram(programs[i]);
477         i++;
478     }
479 
480     if(NULL != command_queue) clReleaseCommandQueue(command_queue);
481     if(NULL != context) clReleaseContext(context);
482     GDALClose((GDALDatasetH)poPANDS);
483     GDALClose((GDALDatasetH)poMSSDS);
484     GDALClose((GDALDatasetH)poOutDS);
485     return 0;
486 }
487 
488 
489 
490 
491 
492 #endif

GPU核函数(OpenCL C,即主机代码中通过strfile加载的ReSampleKernel.txt)代码如下:

  1 #pragma OPENCL EXTENSION cl_amd_printf:enable
  2 
  3 __kernel void ReSampleKernel(__global const float *mssData,
  4                              __global float *resampleData,
  5                              int mssWidth,
  6                              int mssHeight,
  7                              int mssBandCount,
  8                              int mssOffsetX,
  9                              int mssOffsetY,
 10                              int panWidth,
 11                              int panHeight,
 12                              float radioX,
 13                              float radioY,
 14                              float dfDstNoDataValue,
 15                              int MethodType)
 16 {
 17     int idx = get_global_id(1);  // 采样行
 18     int idy = get_global_id(0);  // 采样列
 19     float eps = 0.00001f;
 20     if(idx < panHeight && idy < panWidth){
 21         // 找到对应的MSS像素位置
 22         float curX = (float)idx * radioX;
 23         float curY = (float)idy * radioY;
 24         int tmpP = (int)curX*mssWidth*mssBandCount + (int)curY;
 25         if(mssData[tmpP] == dfDstNoDataValue)
 26         {
 27             int i = 0;
 28             while(i < mssBandCount){
 29                 resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = dfDstNoDataValue;
 30                 i++;
 31             }
 32             return;
 33         }
 34         if(MethodType == 0){  // 最近邻
 35             int nearX = (int)(curX + 0.5)>(int)curX?(int)(curX + 1):(int)curX;
 36             int nearY = (int)(curY + 0.5)>(int)curY?(int)(curY + 1):(int)curY;
 37             if(nearX >= mssHeight - 1){
 38                 nearX = mssHeight - 1;
 39             }
 40             if(nearY >= mssWidth - 1){
 41                 nearY = mssWidth - 1;
 42             }
 43             if(nearX < mssHeight && nearY < mssWidth){
 44                 int i = 0;
 45                 while(i < mssBandCount){
 46                     resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = 
 47                         mssData[nearX*mssWidth*mssBandCount + i*mssWidth + nearY];
 48                     i++;
 49                 }
 50             }
 51         }
 52         if(MethodType == 1){  // 双线性
 53             float dataX = curX - (int)curX;
 54             float dataY = curY - (int)curY;
 55             if(dataX < eps){
 56                 dataX = 0.00001;
 57             }
 58             if(dataY < eps){
 59                 dataY = 0.00001;
 60             }
 61             int preX = (int)curX;
 62             int preY = (int)curY;
 63             int postX = (int)curX + 1;
 64             int postY = (int)curY + 1;
 65             if(postX >= mssHeight - 1){
 66                 postX = mssHeight - 1;
 67             }
 68             if(postY >= mssWidth - 1){
 69                 postY = mssWidth - 1;
 70             }
 71             
 72             float Wx1 = 1 - dataX;
 73             float Wx2 = dataX;
 74             float Wy1 = 1 - dataY;
 75             float Wy2 = dataY;
 76             // 双线性差值核心代码
 77             int i = 0;
 78             while(i < mssBandCount){
 79                 float pMssValue[4] = {0,0,0,0};
 80                 pMssValue[0] = mssData[preX*mssWidth*mssBandCount + i*mssWidth + preY];
 81                 pMssValue[1] = mssData[preX*mssWidth*mssBandCount + i*mssWidth + postY];
 82                 pMssValue[2] = mssData[postX*mssWidth*mssBandCount + i*mssWidth + preY];
 83                 pMssValue[3] = mssData[postX*mssWidth*mssBandCount + i*mssWidth + postY];
 84                 resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = 
 85                     Wy1*(Wx1*pMssValue[0] + Wx2*pMssValue[2]) + Wy2*(Wx1*pMssValue[1] + Wx2*pMssValue[3]);
 86                 i++;
 87             }
 88         }
 89         if(MethodType == 2){  // 双三次卷积
 90             float dataX = curX - (int)curX;
 91             float dataY = curY - (int)curY;
 92             //printf("dataX = %f   dataY = %f\n",dataX,dataY);
 93             int preX1 = (int)curX - 1;
 94             int preX2 = (int)curX;
 95             int postX1 = (int)curX + 1;
 96             int postX2 = (int)curX + 2;
 97             int preY1 = (int)curY - 1;
 98             int preY2 = (int)curY;
 99             int postY1 = (int)curY + 1;
100             int postY2 = (int)curY + 2;
101             if(preX1 < 0) preX1 = 0;
102             if(preY1 < 0) preY1 = 0;
103             if(postX1 > mssHeight - 1) postX1 = mssHeight - 1;
104             if(postX2 > mssHeight - 1) postX2 = mssHeight - 1;
105             if(postY1 > mssWidth - 1) postY1 = mssWidth - 1;
106             if(postY2 > mssWidth - 1) postY2 = mssWidth - 1;
107 
108             float Wx1 = -1.0f*dataX + 2*dataX*dataX - dataX*dataX*dataX;
109             float Wx2 = 1 - 2*dataX*dataX + dataX*dataX*dataX;
110             float Wx3 = dataX + dataX*dataX - dataX*dataX*dataX;
111             float Wx4 = -1.0f*dataX*dataX + dataX*dataX*dataX;
112             float Wy1 = -1.0f*dataY + 2*dataY*dataY - dataY*dataY*dataY;
113             float Wy2 = 1 - 2*dataY*dataY + dataY*dataY*dataY;
114             float Wy3 = dataY + dataY*dataY - dataY*dataY*dataY;
115             float Wy4 = -1.0f*dataY*dataY + dataY*dataY*dataY;
116             
117             //printf("preX1 = %d\n",preX1);
118             int i = 0;
119             while(i < mssBandCount){
120                 float pMssValue[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
121                 pMssValue[0] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + preY1];
122                 pMssValue[1] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + preY2];
123                 pMssValue[2] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + postY1];
124                 pMssValue[3] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + postY2];
125                 
126                 pMssValue[4] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + preY1];
127                 pMssValue[5] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + preY2];
128                 pMssValue[6] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + postY1];
129                 pMssValue[7] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + postY2];
130                 
131                 pMssValue[8] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + preY1];
132                 pMssValue[9] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + preY2];
133                 pMssValue[10] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + postY1];
134                 pMssValue[11] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + postY2];
135                 
136                 pMssValue[12] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + preY1];
137                 pMssValue[13] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + preY2];
138                 pMssValue[14] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + postY1];
139                 pMssValue[15] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + postY2];
140 
141                 resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = 
142                     Wy1*(Wx1*pMssValue[0] + Wx2*pMssValue[4] + Wx3*pMssValue[8] + Wx4*pMssValue[12])+
143                     Wy2*(Wx1*pMssValue[1] + Wx2*pMssValue[5] + Wx3*pMssValue[9] + Wx4*pMssValue[13])+
144                     Wy3*(Wx1*pMssValue[2] + Wx2*pMssValue[6] + Wx3*pMssValue[10] + Wx4*pMssValue[14])+
145                     Wy4*(Wx1*pMssValue[3] + Wx2*pMssValue[7] + Wx3*pMssValue[11] + Wx4*pMssValue[15]);
146                 i++;
147             }    
148         }
149     }
150 }

  以上代码应该可以直接使用(使用前请将主机代码中strfile指向的内核文件路径改为本机实际路径),欢迎大家一起交流探讨。

 

另外,我对GDAL、CPU和GPU三个版本的重采样算法效率进行了对比:在三次卷积重采样算法上,GPU版本的效率明显高于CPU版本。具体结果如下:

posted on 2016-10-18 21:14  机器学习猪  阅读(5054)  评论(2)  编辑  收藏  举报