Comparing the efficiency of three sorting approaches (plain CPU, multithreaded CPU, and GPU via C++ AMP)

Reference: http://www.codeproject.com/Articles/543451/Parallel-Radix-Sort-on-the-GPU-using-Cplusplus-AMP

On an ordinary PC, the multithreaded CPU sort beats the GPU-accelerated sort while the data set is small; once the data set grows large enough, the GPU-accelerated sort pulls ahead.

main.cpp

#include <amp.h>
#include <chrono>
#include <algorithm>
#include <conio.h>
#include <cstdio>
#include <vector>
#include "radix_sort.h"
#include <ppl.h>


int main()
{
    using namespace concurrency;
    accelerator default_device;
    // get_description() returns a std::wstring, so pass its c_str() to wprintf
    wprintf(L"Using device : %s\n\n", default_device.get_description().c_str());
    if (default_device == accelerator(accelerator::direct3d_ref))
        printf("WARNING!! Running on very slow emulator! Only use this accelerator for debugging.\n\n");

    for (uint i = 0; i < 10; i++)
    {
        uint num = 1 << (i + 10);   // test sizes from 2^10 up to 2^19 elements
        printf("Testing for %u elements: \n", num);

        // Fill with the values 0..num-1 and shuffle, so the sorted result is known
        std::vector<uint> data(num);
        for (uint j = 0; j < num; j++)
        {
            data[j] = j;
        }
        std::random_shuffle(data.begin(), data.end());
        std::vector<uint> dataclone(data.begin(), data.end());

        auto start_fill = std::chrono::high_resolution_clock::now();
        array<uint> av(num, data.begin(), data.end());   // upload to the GPU
        auto end_fill = std::chrono::high_resolution_clock::now();

        printf("Allocating %u random unsigned integers complete! Start GPU sort.\n", num);

        auto start_comp = std::chrono::high_resolution_clock::now();
        pal::radix_sort(av);
        av.accelerator_view.wait();   // wait for the computation to finish
        auto end_comp = std::chrono::high_resolution_clock::now();

        auto start_collect = std::chrono::high_resolution_clock::now();
        data = av;   // synchronise: copy the sorted data back to the host
        auto end_collect = std::chrono::high_resolution_clock::now();

        printf("GPU sort completed in %lld microseconds.\nData transfer: %lld microseconds, computation: %lld microseconds\n",
            std::chrono::duration_cast<std::chrono::microseconds>(end_collect - start_fill).count(),
            std::chrono::duration_cast<std::chrono::microseconds>(end_fill - start_fill + end_collect - start_collect).count(),
            std::chrono::duration_cast<std::chrono::microseconds>(end_comp - start_comp).count());

        printf("Testing for correctness. Results are.. ");

        uint success = 1;
        for (uint j = 0; j < num; j++)
        {
            if (data[j] != j) { success = 0; break; }
        }
        printf("%s\n", (success ? "correct!" : "incorrect!"));

        data = dataclone;
        printf("Beginning CPU sorts for comparison.\n");
        start_comp = std::chrono::high_resolution_clock::now();
        std::sort(data.data(), data.data() + num);
        end_comp = std::chrono::high_resolution_clock::now();
        printf("CPU std::sort completed in %lld microseconds. \n",
            std::chrono::duration_cast<std::chrono::microseconds>(end_comp - start_comp).count());

        data = dataclone;
        start_comp = std::chrono::high_resolution_clock::now();
        // Note: the concurrency::parallel sorts are horribly slow if you give them
        // vector iterators (i.e. parallel_radixsort(data.begin(), data.end())),
        // so pass raw pointers instead.
        concurrency::parallel_radixsort(data.data(), data.data() + num);
        end_comp = std::chrono::high_resolution_clock::now();
        printf("CPU concurrency::parallel_radixsort completed in %lld microseconds. \n\n\n",
            std::chrono::duration_cast<std::chrono::microseconds>(end_comp - start_comp).count());
    }

    printf("Press any key to exit! \n");
    _getch();
}

 

radix_sort.h

#pragma once
#include <amp.h>

typedef unsigned int uint;

namespace pal
{
    // In-place LSD radix sort of 32-bit unsigned integers, 2 bits per pass.
    void radix_sort(uint* start, uint num);
    void radix_sort(concurrency::array<uint>& arr);
}
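
For reference, a minimal usage sketch of the pointer overload (the buffer name keys is hypothetical, and this assumes the overload copies the sorted result back to the caller's buffer, as in the corrected radix_sort.cpp below; main.cpp above exercises the concurrency::array overload directly):

#include <vector>
#include "radix_sort.h"

int main()
{
    std::vector<uint> keys = { 42, 7, 19, 3, 23 };   // hypothetical example data
    pal::radix_sort(keys.data(), (uint)keys.size()); // sorts in place via the GPU
    // keys now holds { 3, 7, 19, 23, 42 }
}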

 

radix_sort.cpp

#include <amp.h>
#include "radix_sort.h"


// Copies src into dest, padding any elements past the end of src with val.
void arr_fill(concurrency::array_view<uint>& dest, concurrency::array_view<uint>& src, uint val)
{
    parallel_for_each(dest.extent, [dest, val, src](concurrency::index<1> idx) restrict(amp)
    {
        dest[idx] = ((uint)idx[0] < src.get_extent().size()) ? src[idx] : val;
    });
}

// Extracts numbits bits of x, starting at bitoffset.
uint get_bits(uint x, uint numbits, uint bitoffset) restrict(amp)
{
    return (x >> bitoffset) & ~(~0u << numbits);
}

uint pow2(uint x) restrict(amp, cpu)
{
    return ((uint)1) << x;
}

// Tree reduction: returns the sum of x over all 256 threads of the tile.
// Double-buffered in tile_static memory to avoid read/write races.
uint tile_sum(uint x, concurrency::tiled_index<256> tidx) restrict(amp)
{
    using namespace concurrency;
    uint l_id = tidx.local[0];
    tile_static uint l_sums[256][2];

    l_sums[l_id][0] = x;
    tidx.barrier.wait();

    for (uint i = 0; i < 8; i++)
    {
        if (l_id < pow2(7 - i))
        {
            uint w = (i + 1) % 2;   // write buffer
            uint r = i % 2;         // read buffer

            l_sums[l_id][w] = l_sums[l_id * 2][r] + l_sums[l_id * 2 + 1][r];
        }
        tidx.barrier.wait();
    }
    return l_sums[0][0];
}

// Exclusive prefix sum of x over the tile; last_val receives the tile total.
uint tile_prefix_sum(uint x, concurrency::tiled_index<256> tidx, uint& last_val) restrict(amp)
{
    using namespace concurrency;
    uint l_id = tidx.local[0];
    tile_static uint l_prefix_sums[256][2];

    l_prefix_sums[l_id][0] = x;
    tidx.barrier.wait();

    // Hillis-Steele scan, double-buffered: after 8 doubling steps the
    // inclusive prefix sums sit in column 0.
    for (uint i = 0; i < 8; i++)
    {
        uint pow2i = pow2(i);

        uint w = (i + 1) % 2;
        uint r = i % 2;

        l_prefix_sums[l_id][w] = (l_id >= pow2i) ?
            (l_prefix_sums[l_id][r] + l_prefix_sums[l_id - pow2i][r]) : l_prefix_sums[l_id][r];

        tidx.barrier.wait();
    }
    last_val = l_prefix_sums[255][0];

    // Shift right by one element to turn the inclusive scan into an exclusive one.
    return (l_id == 0) ? 0 : l_prefix_sums[l_id - 1][0];
}

uint tile_prefix_sum(uint x, concurrency::tiled_index<256> tidx) restrict(amp)
{
    uint ll = 0;
    return tile_prefix_sum(x, tidx, ll);
}


// Pass 1: per tile, count how many keys fall into each of the four 2-bit
// buckets, then prefix-sum those counts across all tiles.
void calc_interm_sums(uint bitoffset, concurrency::array<uint>& interm_arr,
                      concurrency::array<uint>& interm_sums, concurrency::array<uint>& interm_prefix_sums, uint num_tiles)
{
    using namespace concurrency;
    auto ext = extent<1>(num_tiles * 256).tile<256>();

    parallel_for_each(ext, [=, &interm_sums, &interm_arr](tiled_index<256> tidx) restrict(amp)
    {
        // Out-of-bounds threads count as key 0xffffffff, i.e. the last bucket.
        uint inbound = ((uint)tidx.global[0] < interm_arr.get_extent().size());
        uint num = inbound ? get_bits(interm_arr[tidx.global[0]], 2, bitoffset) : get_bits(0xffffffff, 2, bitoffset);
        for (uint i = 0; i < 4; i++)
        {
            uint to_sum = (num == i);
            uint sum = tile_sum(to_sum, tidx);

            if (tidx.local[0] == 0)
            {
                interm_sums[i * num_tiles + tidx.tile[0]] = sum;
            }
        }
    });

    // Scan the num_tiles*4 counts with a single 256-thread tile, carrying the
    // running total (last_val1) across 256-element chunks.
    uint numiter = (num_tiles / 64) + ((num_tiles % 64 == 0) ? 0 : 1);
    ext = extent<1>(256).tile<256>();
    parallel_for_each(ext, [=, &interm_prefix_sums, &interm_sums](tiled_index<256> tidx) restrict(amp)
    {
        uint last_val0 = 0;
        uint last_val1 = 0;

        for (uint i = 0; i < numiter; i++)
        {
            uint g_id = tidx.local[0] + i * 256;
            uint num = (g_id < (num_tiles * 4)) ? interm_sums[g_id] : 0;
            uint scan = tile_prefix_sum(num, tidx, last_val0);
            if (g_id < (num_tiles * 4)) interm_prefix_sums[g_id] = scan + last_val1;

            last_val1 += last_val0;
        }
    });
}

// Pass 2: scatter each element to dest at (global bucket offset + rank within tile).
void sort_step(uint bitoffset, concurrency::array<uint>& src, concurrency::array<uint>& dest,
               concurrency::array<uint>& interm_prefix_sums, uint num_tiles)
{
    using namespace concurrency;
    auto ext = extent<1>(num_tiles * 256).tile<256>();

    parallel_for_each(ext, [=, &interm_prefix_sums, &src, &dest](tiled_index<256> tidx) restrict(amp)
    {
        uint inbounds = ((uint)tidx.global[0] < src.get_extent().size());
        uint element = inbounds ? src[tidx.global[0]] : 0xffffffff;
        uint num = get_bits(element, 2, bitoffset);
        for (uint i = 0; i < 4; i++)
        {
            uint scan = tile_prefix_sum((num == i), tidx) + interm_prefix_sums[i * num_tiles + tidx.tile[0]];
            if (num == i && inbounds) dest[scan] = element;
        }
    });
}

namespace pal
{
    void radix_sort(concurrency::array<uint>& arr)
    {
        using namespace concurrency;
        uint size = arr.get_extent().size();

        const uint num_tiles = (size / 256) + ((size % 256 == 0) ? 0 : 1);

        array<uint> interm_arr(size);
        array<uint> interm_sums(num_tiles * 4);
        array<uint> interm_prefix_sums(num_tiles * 4);

        // 16 passes of 2 bits each cover all 32 bits, ping-ponging between
        // arr and interm_arr; after an even number of passes the sorted
        // result ends up back in arr.
        for (uint i = 0; i < 16; i++)
        {
            array<uint>& src  = (i % 2 == 0) ? arr : interm_arr;
            array<uint>& dest = (i % 2 == 0) ? interm_arr : arr;

            uint bitoffset = i * 2;
            calc_interm_sums(bitoffset, src, interm_sums, interm_prefix_sums, num_tiles);
            sort_step(bitoffset, src, dest, interm_prefix_sums, num_tiles);
        }
    }

    void radix_sort(uint* arr, uint size)
    {
        // The original passed a temporary array and never copied the result back,
        // so the caller's buffer was left unsorted. Use a named array and copy
        // the sorted data back out.
        concurrency::array<uint> gpu_arr(size, arr);
        radix_sort(gpu_arr);
        concurrency::copy(gpu_arr, arr);
    }
}
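
To make the two GPU kernels easier to follow, here is a scalar CPU sketch of what a single 2-bit pass computes (cpu_sort_pass is an illustrative name, not part of the source): calc_interm_sums produces the per-bucket counts and their exclusive prefix sums, and sort_step performs the stable scatter.

#include <vector>
typedef unsigned int uint;

// Scalar equivalent of one GPU pass: count the four 2-bit buckets,
// exclusive-prefix-sum the counts, then stably scatter by bucket.
void cpu_sort_pass(const std::vector<uint>& src, std::vector<uint>& dest, uint bitoffset)
{
    uint counts[4] = {};
    for (uint x : src)
        counts[(x >> bitoffset) & 3]++;

    uint offsets[4];
    uint running = 0;
    for (uint i = 0; i < 4; i++)
    {
        offsets[i] = running;   // exclusive prefix sum of the counts
        running += counts[i];
    }

    for (uint x : src)
        dest[offsets[(x >> bitoffset) & 3]++] = x;
}

Running this sixteen times with bitoffset = 0, 2, ..., 30, swapping src and dest between passes, reproduces on the CPU what pal::radix_sort does on the GPU.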

 
