矩阵乘向量 基于哈希表的稀疏矩阵优化

计算N*N 方阵A 乘以 向量B

记录所需时间

CPU:I5 750 @3.6G

 

 

填充同样的 A 方阵数据 和 B向量 数据

 

N的规模

N = 1 << 12 = 4096

稀疏矩阵有效数据填充率:30%

demo1:

使用数组计算

time1 : 19ms

 

demo2:

使用Integer ,Float 包装类,二维稀疏矩阵

time2 : 83ms

 

demo3:

使用基本类型的int 和 float,并只将一维的数组用散列表替换,一维稀疏矩阵

行方向上,用一个数组

time3: 44ms

 

======

第二次尝试

提高N = 1 << 13  = 8192

稀疏矩阵有效数据填充率:30%

 

time : 66ms

time2:java.lang.OutOfMemoryError: GC overhead limit exceeded

time3: 150ms

 

======

第三次

N = 1 << 13  = 8192

稀疏矩阵有效数据填充率:5%

 

time1 : 61ms

time3: 47ms

 

====

第四次

稀疏矩阵有效数据填充率:2%

 


time1 : 62ms

time2 : 87ms

time3: 22ms

 

可以看到time3 得到了极大的提升,如果代码再优化的好一些,hash算法再快速一些,特别是对执行次数非常多语句,比如for中的语句,或循环嵌套中的语句做一些优化,时间成本和收益比会再初期非常大

 

google的 PageRank

使用了类似的方法减少时间复杂度,他的数据量是千亿级,但是矩阵中的稀疏度比较大,为了减少无意义的 0 x 0 和 sum += 0

但是google用的是一种图(不是本例中一样的数据结构)

 

suck code:

import java.util.Random;

public class Matrix {
    public static int N = 1 << 13;
    public static int matrixRandomSeed = 1234;
    public static float fillFactor = 0.98f;

    //一行中的一维数组做成基于 符号表的稀疏向量(set)
    //列: key 行号,val 一行的hashSet

    //优化:1.换数据结构,拉链,或者 红黑
    //2.使用基本类型值拷贝
    LinearProbingHashST<Integer, LinearProbingHashST<Integer, Float>> m1;

    public Matrix() {
        m1 = new LinearProbingHashST<>();
        randInit();
    }

    public void randInit() {
        Random rand = new Random(matrixRandomSeed);

        LinearProbingHashST<Integer, Float> row;
        for (int i = 0; i < N; i++) {
            m1.put(i, row = new LinearProbingHashST<Integer, Float>());
            for (int j = 0; j < N; j++) {
                float v = 0;
                if (rand.nextFloat() > fillFactor) { //填充%的随机数据
                    v = rand.nextFloat() + 0.1f;
                    row.put(j, v);
                }
            }
        }
    }


    public static float[] matrixDemo1() {
        float a[][] = new float[N][N];
        float b[];
        float c[] = new float[N];
        //a*b=c   b postMult a ,b后乘a

        Random rand = new Random(matrixRandomSeed);

        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (rand.nextFloat() > fillFactor) { //填充30%的随机数据
                    a[i][j] = rand.nextFloat() + 0.1f;
                }
            }
        }
        b = genB(N);

        long s1 = System.currentTimeMillis();
        for (int i = 0; i < N; i++) {
            c[i] = mult(a[i], b);
        }
        long s2 = System.currentTimeMillis();
        System.out.println("\ntime1 : " + (s2 - s1) + "ms\n");
        return c;
    }

    public static float mult(float a[], float b[]) {
        float sum = 0.0f;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    public static float[] genB(int n) {
        Random rand = new Random(matrixRandomSeed);
        float b[] = new float[n];
        for (int i = 0; i < n; i++) {
            b[i] = rand.nextFloat();
        }
        return b;
    }

    public static float[] matrixDemo2() {
        Matrix m = new Matrix();

        float[] b = genB(N);
        float[] c = m.mult(b);

        return c;
    }

    public float multRow(LinearProbingHashST<Integer, Float> r, float[] b) {
        float sum = 0;
        Node<Integer, Void> hl = r.headList.head;

        for (int index = 0; hl != null; ) {
            index = hl.k;
            float v = r.get(index);
            sum += b[index] * v;
            hl = hl.n;
        }
        return sum;
    }

    public float[] mult(float[] b) {
        Node<Integer, Void> hl = m1.headList.head;
        float[] c = new float[N];
        long s1 = System.currentTimeMillis();
        while (hl != null) {
            int rowN = hl.k;
            LinearProbingHashST<Integer, Float> row = m1.get(rowN);
            c[rowN] = multRow(row, b);
            hl = hl.n;
        }
        long s2 = System.currentTimeMillis();
        System.out.println("time2 : " + (s2 - s1) + "ms");
        return c;
    }


    public static boolean cmp(float[] a, float[] b) {
        if (a.length != b.length)
            return false;

        for (int i = 0; i < a.length; i++) {
            if (a[i] != b[i])
                return false;
        }

        return true;
    }


    public static SparseVector[] initm3() {
        SparseVector[] m3 = new SparseVector[N];
        Random rand = new Random(matrixRandomSeed);

        //填充矩阵a
        SparseVector row;
        for (int i = 0; i < N; i++) {
            row = m3[i] = new SparseVector();

            for (int j = 0; j < N; j++) {
                float v = 0;
                if (rand.nextFloat() > fillFactor) { //填充接近%的随机数据
                    v = rand.nextFloat() + 0.1f;
                    row.put(j, v);
                }
            }
        }
        return m3;
    }

    public static float mult(SparseVector row, float[] b) {
        IntNode intNode = row.headList.head; //keySet ,没用Iterator,直接用链表遍历
        float sum = 0;
        while (intNode != null) {
            int index = intNode.k;
            sum += row.get(index) * b[index];
            intNode = intNode.n;
        }
        return sum;
    }

    public static float[] matrixDemo3() {
        //init
        SparseVector[] m3a = Matrix.initm3();
        float[] b = genB(N);
        float[] c = new float[N];
        System.out.println();
        long s31 = System.currentTimeMillis();
        for (int i = 0; i < N; i++) {
            c[i] = mult(m3a[i], b);
        }
        long s32 = System.currentTimeMillis();
        System.out.println("time3: " + (s32 - s31) + "ms");
        return c;
    }

    public static void main(String args[]) {
        float[] res1, res2, res3;
        System.out.println("N " + N);
        res1 = matrixDemo1();

        System.out.println("\nmatrixDemo2\n");
        res2 = matrixDemo2();

        System.out.println("\nmatrixDemo3\n");
        res3 = matrixDemo3();
    }
}

 

其他的 More suck code就不贴了,太low太suck

posted on 2019-09-24 23:29  jald  阅读(546)  评论(0编辑  收藏  举报

导航