矩阵乘向量 基于哈希表的稀疏矩阵优化
计算N*N 方阵A 乘以 向量B
记录所需时间
CPU: Intel Core i5-750 @ 3.6 GHz
各实验填充相同的方阵 A 数据和向量 B 数据
N的规模
N = 1 << 12 = 4096
稀疏矩阵有效数据填充率:30%
demo1:
使用数组计算
time1 : 19ms
demo2:
使用 Integer、Float 包装类实现的二维稀疏矩阵（行和列都用散列表存储）
time2 : 83ms
demo3:
使用基本类型 int 和 float，只把每一行的一维数组替换为散列表（一维稀疏向量）；
行方向上仍用一个普通数组来保存各行。
time3: 44ms
======
第二次尝试
将 N 提高到 1 << 13 = 8192
稀疏矩阵有效数据填充率:30%
time1 : 66ms
time2:java.lang.OutOfMemoryError: GC overhead limit exceeded
time3: 150ms
======
第三次
N = 1 << 13 = 8192
稀疏矩阵有效数据填充率:5%
time1 : 61ms
time3: 47ms
====
第四次
稀疏矩阵有效数据填充率:2%
time1 : 62ms
time2 : 87ms
time3: 22ms
可以看到 time3 有了极大的提升。对比前几次的结果：填充率为 30% 时 time3 反而比 time1 慢，只有在填充率足够低时稀疏存储才划算。如果代码再优化得好一些、哈希算法再快一些，特别是对执行次数非常多的语句（例如 for 循环体内或嵌套循环中的语句）做针对性优化，那么在优化初期，收益相对于投入的时间成本会非常大。
Google 的 PageRank
也使用了类似的方法来降低计算量：它的数据规模达到千亿级，但矩阵非常稀疏，这样可以省去大量无意义的 0 × x 乘法和 sum += 0 累加。
不过 Google 用的是一种图结构（与本例中的数据结构不同）。
代码（写得比较粗糙，仅供参考）：
import java.util.Random; public class Matrix { public static int N = 1 << 13; public static int matrixRandomSeed = 1234; public static float fillFactor = 0.98f; //一行中的一维数组做成基于 符号表的稀疏向量(set) //列: key 行号,val 一行的hashSet //优化:1.换数据结构,拉链,或者 红黑 //2.使用基本类型值拷贝 LinearProbingHashST<Integer, LinearProbingHashST<Integer, Float>> m1; public Matrix() { m1 = new LinearProbingHashST<>(); randInit(); } public void randInit() { Random rand = new Random(matrixRandomSeed); LinearProbingHashST<Integer, Float> row; for (int i = 0; i < N; i++) { m1.put(i, row = new LinearProbingHashST<Integer, Float>()); for (int j = 0; j < N; j++) { float v = 0; if (rand.nextFloat() > fillFactor) { //填充%的随机数据 v = rand.nextFloat() + 0.1f; row.put(j, v); } } } } public static float[] matrixDemo1() { float a[][] = new float[N][N]; float b[]; float c[] = new float[N]; //a*b=c b postMult a ,b后乘a Random rand = new Random(matrixRandomSeed); for (int i = 0; i < N; i++) { for (int j = 0; j < N; j++) { if (rand.nextFloat() > fillFactor) { //填充30%的随机数据 a[i][j] = rand.nextFloat() + 0.1f; } } } b = genB(N); long s1 = System.currentTimeMillis(); for (int i = 0; i < N; i++) { c[i] = mult(a[i], b); } long s2 = System.currentTimeMillis(); System.out.println("\ntime1 : " + (s2 - s1) + "ms\n"); return c; } public static float mult(float a[], float b[]) { float sum = 0.0f; for (int i = 0; i < a.length; i++) { sum += a[i] * b[i]; } return sum; } public static float[] genB(int n) { Random rand = new Random(matrixRandomSeed); float b[] = new float[n]; for (int i = 0; i < n; i++) { b[i] = rand.nextFloat(); } return b; } public static float[] matrixDemo2() { Matrix m = new Matrix(); float[] b = genB(N); float[] c = m.mult(b); return c; } public float multRow(LinearProbingHashST<Integer, Float> r, float[] b) { float sum = 0; Node<Integer, Void> hl = r.headList.head; for (int index = 0; hl != null; ) { index = hl.k; float v = r.get(index); sum += b[index] * v; hl = hl.n; } return sum; } public float[] mult(float[] b) { 
Node<Integer, Void> hl = m1.headList.head; float[] c = new float[N]; long s1 = System.currentTimeMillis(); while (hl != null) { int rowN = hl.k; LinearProbingHashST<Integer, Float> row = m1.get(rowN); c[rowN] = multRow(row, b); hl = hl.n; } long s2 = System.currentTimeMillis(); System.out.println("time2 : " + (s2 - s1) + "ms"); return c; } public static boolean cmp(float[] a, float[] b) { if (a.length != b.length) return false; for (int i = 0; i < a.length; i++) { if (a[i] != b[i]) return false; } return true; } public static SparseVector[] initm3() { SparseVector[] m3 = new SparseVector[N]; Random rand = new Random(matrixRandomSeed); //填充矩阵a SparseVector row; for (int i = 0; i < N; i++) { row = m3[i] = new SparseVector(); for (int j = 0; j < N; j++) { float v = 0; if (rand.nextFloat() > fillFactor) { //填充接近%的随机数据 v = rand.nextFloat() + 0.1f; row.put(j, v); } } } return m3; } public static float mult(SparseVector row, float[] b) { IntNode intNode = row.headList.head; //keySet ,没用Iterator,直接用链表遍历 float sum = 0; while (intNode != null) { int index = intNode.k; sum += row.get(index) * b[index]; intNode = intNode.n; } return sum; } public static float[] matrixDemo3() { //init SparseVector[] m3a = Matrix.initm3(); float[] b = genB(N); float[] c = new float[N]; System.out.println(); long s31 = System.currentTimeMillis(); for (int i = 0; i < N; i++) { c[i] = mult(m3a[i], b); } long s32 = System.currentTimeMillis(); System.out.println("time3: " + (s32 - s31) + "ms"); return c; } public static void main(String args[]) { float[] res1, res2, res3; System.out.println("N " + N); res1 = matrixDemo1(); System.out.println("\nmatrixDemo2\n"); res2 = matrixDemo2(); System.out.println("\nmatrixDemo3\n"); res3 = matrixDemo3(); } }
其余代码写得比较粗糙（too low），就不贴了。