cuda标准差拉伸

标准差拉伸(tif影像波段值类型由16bit转为8bit)cuda实现版本

使用 GDAL 2.4.4、CUDA 10.1 以及 Thrust 库(用于计算波段均值、方差)

  1. 使用 gdal2.4.4 读取 GTiff 格式影像,读取数据至数组

  2. 使用 Thrust 库计算最大值、最小值、波段均值、方差等统计量

  3. 使用 CUDA 10.1 核函数对每个像素按条件判断进行拉伸赋值

头文件引用

  • thrust计算最大值、最小值引用 

#include "thrust/extrema.h"

  • 设备指针

#include "thrust/device_vector.h"

  • thrust 可以在 cpu 和 gpu 端执行

#include "thrust/execution_policy.h"

通过调用函数的第一个参数(执行策略)指定在 CPU 还是 GPU 端执行,例如 thrust::reduce(thrust::host, ...) 或 thrust::reduce(thrust::device, ...)。
  • 在累加求和时注意总和值的类型:数组元素类型为 unsigned short,求和结果会远远超过该类型的最大值,故在
auto band_sum = thrust::reduce(ptr, ptr + size, (ull)0); 中通过初始值 (ull)0 将累加类型指定为 unsigned long long。

代码如下:

复制代码
  1 #include "cuda_runtime.h"
  2 #include "device_launch_parameters.h"
  3 
  4 #include "thrust/host_vector.h"
  5 #include "thrust/device_vector.h"
  6 #include "thrust/extrema.h"
  7 #include "thrust/reduce.h"
  8 #include "thrust/functional.h"
  9 #include "thrust/execution_policy.h"
 10 
 11 #include "gdal_util.h"
 12 #include "cpl_conv.h"
 13 
 14 // 求方差
 // Thrust transform_reduce functor: maps one pixel value to its squared
// deviation from a fixed mean, (data - mean)^2, evaluated in double.
// Summing these and dividing by (n - 1) yields the sample variance.
struct variance
{
    // Typedefs formerly inherited from std::unary_function (deprecated in
    // C++11, removed in C++17); declared directly to stay source-compatible.
    typedef us argument_type;
    typedef double result_type;

    variance(double m) : mean(m) { }
    const double mean;

    __host__ __device__ double operator()(us data) const
    {
        // A plain product avoids a generic pow() call in device code.
        const double diff = data - mean;
        return diff * diff;
    }
};
 24 
 // Linear standard-deviation stretch of 16-bit pixels into 8-bit output.
// Launch layout: 1-D grid, one thread per pixel; threads whose global index
// is >= size exit immediately.
//   data     : input pixels in device memory, `size` elements
//   res      : output pixels in device memory, `size` elements
//   band_max : upper bound of the output range
//   band_min : lower bound of the output range
//   uc_max   : upper edge of the stretch window, in input units
//   uc_min   : lower edge of the stretch window, in input units
//   k, b     : linear map out = k * in + b from the window to the output range
// Pixels equal to band_min or at/below uc_min become band_min, pixels at/above
// uc_max become band_max, all others are mapped linearly and clamped to
// [band_min, band_max].
__global__ void pixels_std(us* data, uc* res, const ull size, us band_max, us band_min, us uc_max, us uc_min, float k, float b)
{
    // 64-bit index so rasters with more than 2^32 pixels do not wrap around.
    const ull tid = static_cast<ull>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (tid >= size) return;

    const us d = data[tid];
    us v;
    if (d == band_min || d <= uc_min)
    {
        v = band_min;
    }
    else if (d >= uc_max)
    {
        v = band_max;
    }
    else
    {
        const float y = k * d + b;
        // Inclusive clamp: a value landing exactly on a bound is set to that
        // bound instead of leaving the raw 16-bit pixel (which would be
        // truncated mod 256 below). The negated form also maps NaN to band_min.
        if (!(y > band_min))
            v = band_min;
        else if (!(y < band_max))
            v = band_max;
        else
            v = static_cast<us>(y);  // truncating conversion
    }

    res[tid] = static_cast<uc>(v);
}
 49 
 // Prints a diagnostic for a failed CUDA runtime call.
// Returns true when `err` is cudaSuccess, false otherwise.
static bool cuda_ok(cudaError_t err, const char* what)
{
    if (err == cudaSuccess)
        return true;
    fprintf(stderr, "CUDA error (%s): %s\n", what, cudaGetErrorString(err));
    return false;
}

// Converts every band of a 16-bit GTiff image to 8 bits with a linear
// stretch over mean +/- 2.5 standard deviations, computed on the GPU.
// Usage: prog [input.tif [output.tif]]; missing arguments fall back to the
// original hard-coded paths. Returns 0 on success and 1 on any failure.
int main(int argc, char* argv[])
{
    // 16-bit -> 8-bit
    GDALAllRegister();

    const char* psz_filename = argc > 1 ? argv[1] : "D:\\cuda\\PAN31.TIF";
    const char* psz_filename_new = argc > 2 ? argv[2] : "D:\\cuda\\PANNew.TIF";

    GDALDriver* tifDriver = GetGDALDriverManager()->GetDriverByName("GTiff");
    if (tifDriver == NULL)
    {
        fprintf(stderr, "GTiff driver is not available\n");
        return 1;
    }
    raster_info ri;

    CPLSetConfigOption("GDAL_FILENAME_IS_UTF8", "NO");
    // The source image is only read, so open it read-only.
    GDALDatasetH dataset_uint16 = GDALOpen(psz_filename, GA_ReadOnly);
    if (dataset_uint16 == NULL)
    {
        fprintf(stderr, "Cannot open %s\n", psz_filename);
        return 1;
    }
    get_raster_info(dataset_uint16, &ri);
    const int band_count = GDALGetRasterCount(dataset_uint16);

    // Output image: same size, band count and georeferencing, 8-bit samples.
    GDALDataset* dataset_uint8 = tifDriver->Create(psz_filename_new, ri.width, ri.height, band_count, GDT_Byte, NULL);
    if (dataset_uint8 == NULL)
    {
        fprintf(stderr, "Cannot create %s\n", psz_filename_new);
        GDALClose(dataset_uint16);
        return 1;
    }
    dataset_uint8->SetGeoTransform(ri.geo_transform);
    dataset_uint8->SetProjection(ri.projection);

    printf("Size is %dx%dx%d\n",
        ri.width,
        ri.height,
        GDALGetRasterCount(dataset_uint8));
    printf("Pixel Size = (%.6f,%.6f)\n",
        ri.geo_transform[1], ri.geo_transform[5]);

    const int x_size = ri.width;
    const int y_size = ri.height;
    // Widen before multiplying: an int * int product overflows for large rasters.
    const ull size = static_cast<ull>(x_size) * static_cast<ull>(y_size);
    const ull malloc_size = sizeof(us) * size;

    int exit_code = 0;
    us* h_data = NULL;  // source band, host
    uc* h_res = NULL;   // stretched band, host
    us* d_data = NULL;  // source band, device
    uc* d_res = NULL;   // stretched band, device

    if (size > 0)
    {
        h_data = (us*)CPLMalloc(malloc_size);
        h_res = (uc*)CPLMalloc(size);
        if (!cuda_ok(cudaMalloc((void**)&d_data, malloc_size), "cudaMalloc d_data")
            || !cuda_ok(cudaMalloc((void**)&d_res, size), "cudaMalloc d_res"))
            exit_code = 1;
    }

    // Output range is the full 8-bit range.
    const us band_max = 255;
    const us band_min = 0;
    // Half-width of the stretch window, in standard deviations.
    const float kn = 2.5f;
    const ui block_size = 128;

    try
    {
        for (int i = 0; size > 0 && exit_code == 0 && i < band_count; ++i)
        {
            GDALRasterBandH h_band = GDALGetRasterBand(dataset_uint16, i + 1);
            GDALRasterBandH h_band2 = GDALGetRasterBand(dataset_uint8, i + 1);
            if (GDALRasterIO(h_band, GF_Read, 0, 0, x_size, y_size,
                    h_data, x_size, y_size, GDT_UInt16, 0, 0) != CE_None)
            {
                fprintf(stderr, "Failed to read band %d\n", i + 1);
                exit_code = 1;
                break;
            }
            if (!cuda_ok(cudaMemcpy(d_data, h_data, malloc_size, cudaMemcpyHostToDevice), "cudaMemcpy to device"))
            {
                exit_code = 1;
                break;
            }

            thrust::device_ptr<us> ptr(d_data);
            // Accumulate in unsigned long long: a sum of unsigned short pixels
            // exceeds the unsigned short range almost immediately.
            const ull band_sum = thrust::reduce(ptr, ptr + size, (ull)0);
            const double band_mean = band_sum / (double)size;
            // Sum of squared deviations (val - mean)^2.
            const double band_sq_dev = thrust::transform_reduce(ptr, ptr + size, variance(band_mean), 0.0, thrust::plus<double>());
            // Sample standard deviation; with a single pixel there is none, use 0.
            const double band_std = size > 1 ? std::sqrt(band_sq_dev / (double)(size - 1)) : 0.0;

            float uc_max = band_mean + kn * band_std;
            float uc_min = band_mean - kn * band_std;
            // Linear map from [uc_min, uc_max] onto [band_min, band_max];
            // guarded against a zero-width window (constant band).
            const float range = uc_max - uc_min;
            const float k = range > 0 ? (band_max - band_min) / range : 0.0f;
            const float b = range > 0 ? (uc_max * band_min - uc_min * band_max) / range : (float)band_min;
            // The kernel takes the window as unsigned short; keep it representable.
            if (uc_min <= 0)
                uc_min = 0;
            if (uc_max > 65535.0f)
                uc_max = 65535.0f;

            const ui grid_size = static_cast<ui>((size - 1) / block_size + 1);
            pixels_std<<<grid_size, block_size>>>(d_data, d_res, size, band_max, band_min,
                static_cast<us>(uc_max), static_cast<us>(uc_min), k, b);
            if (!cuda_ok(cudaGetLastError(), "pixels_std launch")
                || !cuda_ok(cudaDeviceSynchronize(), "pixels_std execution")
                || !cuda_ok(cudaMemcpy(h_res, d_res, size, cudaMemcpyDeviceToHost), "cudaMemcpy to host"))
            {
                exit_code = 1;
                break;
            }

            if (GDALRasterIO(h_band2, GF_Write, 0, 0, x_size, y_size,
                    h_res, x_size, y_size, GDT_Byte, 0, 0) != CE_None)
            {
                fprintf(stderr, "Failed to write band %d\n", i + 1);
                exit_code = 1;
                break;
            }
        }
    }
    catch (const std::exception& e)
    {
        // Thrust algorithms report CUDA failures by throwing.
        fprintf(stderr, "Thrust error: %s\n", e.what());
        exit_code = 1;
    }

    cudaFree(d_data);
    cudaFree(d_res);
    CPLFree(h_data);
    CPLFree(h_res);

    GDALClose(dataset_uint16);
    GDALClose(dataset_uint8);

    return exit_code;
}
复制代码

python版本 使用 gdal+numpy实现,GitHub链接:

python实现版本

 

posted @ zgcx  阅读(224)  评论(0)  编辑  收藏  举报
点击右上角即可分享
微信分享提示