upc-9541 矩阵乘法 (矩阵分块)

题目描述

深度学习算法很大程度上基于矩阵运算。例如神经网络中的全连接本质上是一个矩阵乘法,而卷积运算也通常是用矩阵乘法来实现的。有一些科研工作者为了让神经网络的计算更快捷,提出了二值化网络的方法,就是将网络权重压缩成只用两种值表示的形式,这样就可以用一些 trick 加速计算了。例如两个二进制向量点乘,可以用计算机中的与运算代替,然后统计结果中 1 的个数即可。
然而有时候为了降低压缩带来的误差,只允许其中一个矩阵被压缩成二进制。这样的情况下矩阵乘法运算还能否做进一步优化呢?给定一个整数矩阵A 和一个二值矩阵B,计算矩阵乘法 C=A×B。为了减少输出,你只需要计算 C 中所有元素的的异或和即可。
 

 

输入

第一行有三个整数 N,P,M, 表示矩阵 A,B 的大小分别是 N×P,P×M 。
接下来 N 行是矩阵 A 的值,每一行有 P 个数字。第 i+1 行第 j 列的数字为 Ai,j, Ai,j 用大写的16进制表示(即只包含 0~9, A~F),每个数字后面都有一个空格。
接下来 M 行是矩阵 B 的值,每一行是一个长为 P 的 01字符串。第 i+N+1 行第 j 个字符表示 Bj,i 的值。
 

 

输出

一个整数,矩阵 C 中所有元素的异或和。

 

样例输入

4 2 3
3 4
8 A
F 5
6 7
01
11
10

 

样例输出

2

 

提示

2≤N,M≤4096,1≤P≤64,0≤Ai,j<65536,0≤Bi,j≤1.

 
看起来是矩阵分块,但是数据比较水,for for for暴力循环就能过题。
 
由于矩阵b是个01矩阵,所以如果按8位分块,一块最多有256种情况,预处理分块后极限数据时间复杂度为5e8.
#include "bits/stdc++.h"

using namespace std;
const int maxn = 4100;
int a[maxn][70], b[maxn][70];
int ap[maxn][10][260], bp[maxn][10];

int main() {
    //freopen("input.txt", "r", stdin);
    int n, p, m;
    scanf("%d %d %d", &n, &p, &m);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < p; j++) {
            scanf("%x", &a[i][j]);
        }
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < p; j++) {
            scanf("%1d", &b[i][j]);
        }
    }
    p = (p + 7) / 8;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < p; j++) {
            int base = j * 8;
            for (int k = 0; k < 256; k++) {
                for (int l = 0; l < 8; l++) {
                    if (k & (1 << l)) {
                        ap[i][j][k] += a[i][base + l];
                    }
                }
            }
        }
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < p; j++) {
            int base = j * 8;
            for (int k = 0; k < 8; k++) {
                bp[i][j] += (b[i][base + k] << k);
            }
        }
    }
    int ans = 0, temp;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            temp = 0;
            for (int k = 0; k < p; k++) {
                temp += ap[i][k][bp[j][k]];
            }
            ans ^= temp;
        }
    }
    printf("%d\n", ans);
    return 0;
}

 

posted @ 2018-10-10 15:37  Albert_liu  阅读(1344)  评论(0编辑  收藏  举报