External Sorting Model
A refresher on the memory hierarchy: because of the increasingly striking performance gap between processor and memory, computer designers must exploit the principle of locality to present users with as much memory as is available in the cheapest technology, at the speed offered by the fastest technology. This goal gives rise to the so-called memory hierarchy, in which storage at one level serves as a cache for storage at the next lower level. The following picture, courtesy of CSAPP, illustrates this:
In the course CS359 Computer Architecture, we learned four significant aspects of cache design:
(1) block placement, e.g. set-associative placement;
(2) block identification (physical address = cache tag + cache index + block offset);
(3) block replacement, e.g. LRU and FIFO;
(4) write strategies, e.g. write-through on a hit with no-write-allocate on a miss, or write-back on a hit with write-allocate on a miss.
Given (1) and (2), one cache optimization trick is to let the virtual page offset cover both the cache index and the block offset, so that the TLB and the cache can be accessed simultaneously. For example, with a 12-bit page offset and a 6-bit block offset, we would choose a 6-bit L1 cache index, since 12 − 6 = 6 bits of the page offset remain. We also get the following formula:
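cache capacity = (# sets) × (# ways) × block size, where typically block size = 64 bytes (a 6-bit block offset). For example, a 6-bit index gives 64 sets, so an 8-way cache with 64-byte blocks holds 64 × 8 × 64 bytes = 32 KiB.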
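The following program simulates a naive external sorting process on 256 integers, with a main memory that can hold only 8 integers. The data are first cut into sorted runs of 8 integers. The runs are then combined by a polyphase merge over three tapes, using a Fibonacci distribution of runs. I wrote a class called Number to implement some basic number-theory algorithms. One of them, the modular multiplicative inverse used to derive the private exponent, plays a key role in RSA encryption. I use Java TCP socket communication with RSA encryption to simulate data transfer between the disk and main memory.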
import java.util.*;
import java.net.*;
import java.io.*;

/**
 * Basic number-theory routines used by the RSA simulation.
 * Intermediate products that may exceed the int range are computed in long.
 */
class Number {
    /**
     * Tests primality by trial division.
     *
     * @param k the integer to test
     * @return true iff k is prime (false for every k < 2)
     */
    public static boolean isPrime(int k) {
        if (k < 2) {
            return false; // 0, 1 and negative numbers are not prime
        }
        // i <= k / i is i * i <= k without the int overflow near Integer.MAX_VALUE
        for (int i = 2; i <= k / i; i++) {
            if (k % i == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes (x^n) mod m by square-and-multiply.
     * Precondition: n >= 0 && m > 0 (x is reduced modulo m first).
     *
     * @return y = (x^n) % m, in [0, m)
     * @throws IllegalArgumentException if n < 0 or m <= 0
     */
    public static int modexp(int x, int n, int m) {
        if (n < 0 || m <= 0) {
            throw new IllegalArgumentException("modexp requires n>=0 && m>0, got n=" + n + ", m=" + m);
        }
        long base = Math.floorMod(x, m);
        long result = 1 % m; // x^0 mod 1 is 0, not 1
        for (int exp = n; exp > 0; exp >>= 1) {
            if ((exp & 1) != 0) {
                result = result * base % m; // long: the product of two values < m cannot overflow
            }
            base = base * base % m;
        }
        return (int) result;
    }

    /**
     * Extended Euclidean algorithm.
     * Precondition: n >= 0 && m > 0 && ref.length >= 2.
     *
     * @return the gcd of m and n; on return gcd(m,n) = m*ref[0] + n*ref[1]
     */
    public static int gcd(int m, int n, int[] ref) {
        if (n == 0) {
            ref[0] = 1;
            ref[1] = 0;
            return m;
        }
        int val = gcd(n, m % n, ref);
        // back-substitute the coefficients of gcd(n, m % n) into those of gcd(m, n)
        int y = ref[1];
        ref[1] = ref[0] - m / n * y;
        ref[0] = y;
        return val;
    }

    /**
     * Finds the multiplicative inverse of x modulo n.
     * Precondition: x > 0 && n > 0 && x is relatively prime to n.
     *
     * @return y in [0, n) such that x*y = 1 (mod n)
     * @throws RuntimeException if x and n are not relatively prime
     */
    public static int inverse(int x, int n) {
        int[] ref = new int[2];
        if (gcd(x, n, ref) != 1) {
            throw new RuntimeException("Error in Number.inverse()");
        }
        return Math.floorMod(ref[0], n);
    }

    /**
     * Solves the modular linear equation a*x = b (mod n).
     * Precondition: val.length >= gcd(a, n).
     * Postcondition: a*val[i] = b (mod n) for i = 0, 1, ..., d-1.
     *
     * @return the number of solutions d (0 when there is none)
     */
    public static int modLinEqu(int a, int b, int n, int[] val) {
        int[] ref = new int[2];
        int d = gcd(a, n, ref);
        if (b % d != 0) {
            return 0; // no solution
        }
        // ref[0]*(b/d) is one solution; reduce into [0, n) in long to avoid overflow
        val[0] = (int) Math.floorMod((long) ref[0] * (b / d), (long) n);
        int step = n / d;
        for (int i = 1; i < d; i++) {
            val[i] = (int) (((long) val[i - 1] + step) % n);
        }
        return d;
    }

    /**
     * Chinese Remainder Theorem: solves x = b[i] (mod n[i]) for i = 0, 1, ..., k-1,
     * where the n[i] are pairwise relatively prime.
     *
     * @return the solution in [0, n[0]*...*n[k-1])
     */
    public static int solveCRT(int k, int[] b, int[] n) {
        int m = 1;
        for (int i = 0; i < k; i++) {
            m *= n[i];
        }
        long val = 0;
        for (int i = 0; i < k; i++) {
            int mi = m / n[i];
            // mi * inverse(mi, n[i]) is 1 mod n[i] and 0 mod every other n[j]
            long term = (long) mi * inverse(mi, n[i]) % m;
            term = term * Math.floorMod(b[i], n[i]) % m;
            val = (val + term) % m;
        }
        return (int) val;
    }
}

/**
 * Toy RSA cipher used to "encrypt" the integers exchanged between disk and memory.
 * NOTE(review): p*q must fit in an int.
 */
class RSA {
    public int p, q, e, d;

    /**
     * @param p a prime
     * @param q another prime
     * @param e an odd number relatively prime to (p-1)*(q-1)
     * @throws RuntimeException if e has no inverse modulo (p-1)*(q-1)
     */
    public RSA(int p, int q, int e) {
        this.p = p;
        this.q = q;
        this.e = e;
        d = Number.inverse(e, (p - 1) * (q - 1)); // private exponent: e*d = 1 (mod (p-1)*(q-1))
    }

    /** Encrypts k, which must lie in [0, p*q). */
    public int encode(int k) {
        if (k < 0 || k >= p * q) {
            throw new RuntimeException("Error in RSA.encode(int)");
        }
        return Number.modexp(k, e, p * q);
    }

    /** Decrypts k, which must lie in [0, p*q). */
    public int decode(int k) {
        if (k < 0 || k >= p * q) {
            throw new RuntimeException("Error in RSA.decode(int)");
        }
        return Number.modexp(k, d, p * q);
    }

    /** Sends k encrypted, one integer per line, and flushes. */
    public void sendMsg(PrintWriter out, int k) {
        out.println(encode(k));
        out.flush();
    }

    /**
     * Reads one encrypted integer line and decrypts it.
     *
     * @throws EOFException if the peer closed the connection
     */
    public int getMsg(BufferedReader in) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new EOFException("Connection closed while waiting for a message");
        }
        return decode(Integer.parseInt(line));
    }
}

/**
 * Simulated main memory: a TCP server that can hold {@code size} integers and sorts them on request.
 * The constructor binds the port and starts the server thread.
 */
class MemManager extends Thread {
    public static int size = 8;
    private final int[] mem;
    private final ServerSocket server;
    private final RSA rsa;

    public MemManager(int port, RSA rsa) {
        mem = new int[size];
        this.rsa = rsa;
        try {
            // Bind before starting, so a client that connects as soon as this
            // constructor returns is never refused.
            server = new ServerSocket(port);
        } catch (IOException except) {
            throw new UncheckedIOException("Cannot listen on port " + port, except);
        }
        start();
    }

    @Override
    public void run() {
        try (ServerSocket listener = server;
             Socket socket = listener.accept();
             BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
             PrintWriter out = new PrintWriter(socket.getOutputStream())) {
            // Protocol: receive an integer num followed by num input integers;
            // reply with the minimum's index followed by the num sorted integers.
            // The communication terminates when num <= 0.
            int num = rsa.getMsg(in);
            while (num > 0) {
                if (num > mem.length) {
                    throw new IOException("Request for " + num + " integers exceeds memory size " + mem.length);
                }
                for (int i = 0; i < num; i++) {
                    mem[i] = rsa.getMsg(in);
                }
                int minIdx = 0;
                for (int i = 1; i < num; i++) {
                    if (mem[i] < mem[minIdx]) {
                        minIdx = i;
                    }
                }
                Arrays.sort(mem, 0, num);
                rsa.sendMsg(out, minIdx);
                for (int i = 0; i < num; i++) {
                    rsa.sendMsg(out, mem[i]);
                }
                num = rsa.getMsg(in);
            }
        } catch (Exception except) {
            System.err.println("Error: " + except);
        }
    }
}

/**
 * Simulated disk with three tapes: generates random data and sorts it with a polyphase
 * merge, using MemManager (over TCP) for every in-memory sort and comparison.
 * The constructor starts the client thread.
 */
class DiskManager extends Thread {
    /** Sentinel larger than any data value: the largest integer RSA can encode. */
    public static int INF;
    private final int num = 256;
    private final int[] size;  // run length currently stored on each tape
    private final int[] fib;   // number of runs currently stored on each tape
    private final int[][] disk;
    private final int port;
    private BufferedReader in;
    private PrintWriter out;
    private final RSA rsa;

    public DiskManager(int port, RSA rsa) {
        INF = rsa.p * rsa.q - 1;
        size = new int[3];
        fib = new int[3];
        disk = new int[3][2048];
        this.port = port;
        this.rsa = rsa;
        start();
    }

    @Override
    public void run() {
        try (Socket socket = new Socket("localhost", port);
             BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream()));
             PrintWriter writer = new PrintWriter(socket.getOutputStream())) {
            in = reader;
            out = writer;
            Random rand = new Random();
            for (int i = 0; i < num; i++) {
                disk[0][i] = rand.nextInt(512);
            }
            showRes(exterSort());
        } catch (Exception except) {
            System.err.println("Error: " + except);
        }
    }

    /** Sorts disk[0][0..num) and returns the index of the tape that holds the result. */
    private int exterSort() throws IOException {
        int memSize = MemManager.size;
        int tape = preprocess(memSize);
        if (fib[tape] <= 1) {
            // A single run is already fully sorted in disk[tape]; nothing to merge.
            System.out.println("Merging Completes.\n");
        } else {
            // Distribute the Fibonacci-split runs: fib[big] + fib[small] == fib[tape].
            int big = (tape + 2) % 3;
            int small = (tape + 1) % 3;
            int bigLen = fib[big] * memSize;
            System.arraycopy(disk[tape], 0, disk[big], 0, bigLen);
            System.arraycopy(disk[tape], bigLen, disk[small], 0, fib[small] * memSize);
            size[big] = memSize;
            size[small] = memSize;
            tape = merge(big, small);
        }
        rsa.sendMsg(out, 0); // signal the server to quit
        return tape;
    }

    /**
     * Sorts the input into fib[tape] runs of memSize integers in disk[tape]: real runs first,
     * then all-INF dummy runs. Returns that tape index.
     */
    private int preprocess(int memSize) throws IOException {
        int segNum = (num + memSize - 1) / memSize;
        Arrays.fill(disk[0], num, segNum * memSize, INF); // pad the last run with sentinels
        int tape = genFib(segNum);
        for (int i = 0; i < segNum; i++) {
            rsa.sendMsg(out, memSize);
            for (int j = 0; j < memSize; j++) {
                rsa.sendMsg(out, disk[0][i * memSize + j]);
            }
            rsa.getMsg(in); // minimum's index, unused here
            for (int j = 0; j < memSize; j++) {
                disk[tape][i * memSize + j] = rsa.getMsg(in);
            }
        }
        Arrays.fill(disk[tape], segNum * memSize, fib[tape] * memSize, INF); // dummy runs
        System.out.println("\nPreprocessing Completes.");
        return tape;
    }

    /**
     * Fills fib[] cyclically with Fibonacci numbers.
     * Returns pos such that fib[pos] is the minimum Fibonacci number not less than segNum.
     */
    private int genFib(int segNum) {
        fib[1] = 1;
        int pos = 1;
        while (fib[pos] < segNum) {
            pos = (pos + 1) % 3;
            fib[pos] = fib[(pos + 2) % 3] + fib[(pos + 1) % 3];
        }
        return pos;
    }

    /**
     * Polyphase merge. Repeatedly merges fib[tape2] run pairs of disk[tape1] and disk[tape2]
     * into the third tape, until the second tape is empty.
     * Returns the tape that holds the final run.
     */
    private int merge(int tape1, int tape2) throws IOException {
        while (fib[tape2] > 0) {
            int target = 3 - tape1 - tape2;
            System.out.print("Merging " + fib[tape1] + "\t segments with \t");
            System.out.println(fib[tape2] + "\tsegments ...");
            int[] pos = new int[3];
            for (int k = 1; k <= fib[tape2]; k++) {
                mergeHelp(k, pos, tape1, tape2);
            }
            // Shift the unmerged runs of tape1 to the front of the tape.
            int remaining = fib[tape1] * size[tape1] - pos[tape1];
            if (remaining > 0) {
                System.arraycopy(disk[tape1], pos[tape1], disk[tape1], 0, remaining);
            }
            size[target] = size[tape1] + size[tape2];
            fib[target] = fib[tape2];
            fib[tape1] -= fib[tape2];
            // Next phase: merge the new output tape with the leftover runs of tape1.
            int nextTape2 = tape1;
            tape1 = target;
            tape2 = nextTape2;
        }
        System.out.println("Merging Completes.\n");
        return tape1;
    }

    /**
     * Merges the k-th run of disk[tape1] with the k-th run of disk[tape2] into disk[3-tape1-tape2].
     * Every comparison is done by sending the two head elements to MemManager.
     */
    private void mergeHelp(int k, int[] pos, int tape1, int tape2) throws IOException {
        int tape = 3 - tape1 - tape2;
        while (pos[tape1] < k * size[tape1] && pos[tape2] < k * size[tape2]) {
            rsa.sendMsg(out, 2);
            rsa.sendMsg(out, disk[tape1][pos[tape1]]);
            rsa.sendMsg(out, disk[tape2][pos[tape2]]);
            if (rsa.getMsg(in) == 0) {
                pos[tape1]++;
            } else {
                pos[tape2]++;
            }
            disk[tape][pos[tape]++] = rsa.getMsg(in); // the smaller element
            rsa.getMsg(in);                           // the larger element, discarded
        }
        while (pos[tape1] < k * size[tape1]) {
            disk[tape][pos[tape]++] = disk[tape1][pos[tape1]++];
        }
        while (pos[tape2] < k * size[tape2]) {
            disk[tape][pos[tape]++] = disk[tape2][pos[tape2]++];
        }
    }

    /** Prints the first num integers of disk[tape], 16 per line. */
    private void showRes(int tape) {
        System.out.println("Sorting Result:");
        int col = 0;
        for (int i = 0; i < num; i++) {
            if (col > 0) {
                System.out.print("\t\t");
            }
            System.out.print(disk[tape][i]);
            col = (col + 1) & 15; // 16 integers a line
            if (col == 0) {
                System.out.println();
            }
        }
        if (col > 0) {
            System.out.println();
        }
    }
}

public class Main {
    public static void main(String[] args) {
        System.out.println("This program may take some time.");
        System.out.println("Thanks for your patience!\t:-)");
        // p = 97, q = 101, e = 9599
        RSA rsa = new RSA(97, 101, 9599);
        MemManager mem;
        try {
            mem = new MemManager(10086, rsa); // binds the port before returning
        } catch (UncheckedIOException except) {
            System.err.println("Error: " + except);
            return;
        }
        DiskManager disk = new DiskManager(10086, rsa);
        try {
            disk.join();
            mem.join();
        } catch (InterruptedException except) {
            Thread.currentThread().interrupt();
            System.err.println("Error: " + except);
        }
    }
}