矩阵乘法写法比较

免责声明

测的比较随意,有吹黑哨的嫌疑。看一下就好了。

测试对象

取 \(n=1000\),测试两个 \(n\times n\) 的 atcoder::modint998244353 矩阵相乘的速度。

矩阵数字生成:mt19937 rng{1}

正确结果矩阵元素异或和:6597111

编译选项: g++ $< -o $@ -O2 -std=c++14 -static

编译器版本:g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0

系统:Linux LAPTOP-VTBPQCQP 5.15.153.1-microsoft-standard-WSL2 #1 SMP Fri Mar 29 23:14:13 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

测试代码

#!/bin/env pypy3
import os
# Prelude written at the top of every generated main.cpp: standard headers,
# the AtCoder modint header, the `mint` alias, and the global size `n` that
# every template's loops run up to.
src_head = """#include <bits/stdc++.h>
#include "atcoder/modint"
using namespace std;
using LL = long long;
using mint = atcoder::modint998244353;
int n = 1000;
"""
# Shared main() appended after the template: calls the template's init(),
# fills a and then b row by row from mt19937 seeded with 1, calls the
# template's mul(), and prints the XOR of c[i][j].val() over all entries as a
# correctness checksum. It subscripts a/b/c as [i][j], so a template must
# provide two-level indexing for this tail to compile.
src_tail = """
int main() { mt19937 rng{1}; init(); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { a[i][j] = rng(); } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { b[i][j] = rng(); } } mul(); int ret = 0; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { ret ^= c[i][j].val(); } } cout << ret << endl; return 0; } """
def test(code, name, order):
    """Generate main.cpp from a C++ template, then build and time it.

    Writes ``src_head``, the template ``code`` with its ``{0}``, ``{1}``, ...
    placeholders filled by the characters of ``order``, and ``src_tail`` to
    ``main.cpp`` (each followed by a newline). It then prints the
    ``name order`` header and runs ``./test.sh``.

    Args:
        code: C++ source template for ``str.format``. Literal C++ braces
            must be doubled (``{{`` / ``}}``).
        name: Label printed before the timings, e.g. ``"mint[][]"``.
        order: Loop-variable order, one character per placeholder, e.g.
            ``"ikj"`` for three-loop templates or ``"ik"`` for two-loop ones.

    Returns:
        The value ``os.system`` reports for ``./test.sh``.
    """
    # Unpack however many loop indices the template uses. Indexing
    # order[2] raised IndexError for the two-loop valarray templates.
    # Formatting before opening the file also keeps main.cpp untouched when
    # the template is malformed.
    body = code.format(*order)
    with open("main.cpp", "w", encoding="utf-8") as file:
        print(src_head, file=file)
        print(body, file=file)
        print(src_tail, file=file)
    # Flush before the child process writes to the same stream. Otherwise,
    # when stdout is redirected to a file or pipe, this buffered header would
    # show up after the timings printed by ./test.sh.
    print(name, order, flush=True)
    return os.system("./test.sh")
# Template for built-in 2-D arrays. str.format fills {0}, {1}, {2} with the
# loop variables in the requested order (e.g. "ikj"). The loop body always
# uses i, j, k, so every permutation accumulates the same products into c.
# Literal C++ braces are doubled so str.format emits them as single braces.
array_array = """
mint a[1000][1000], b[1000][1000], c[1000][1000];
void init() {{ }}
void mul() {{
for (int {0} = 0; {0} < n; {0}++) 
for (int {1} = 0; {1} < n; {1}++) 
for (int {2} = 0; {2} < n; {2}++) 
    c[i][j] += a[i][k] * b[k][j];
}}"""

# Benchmark all six permutations of the loop indices for the plain-array
# template. Guarded so importing this script does not start the (slow)
# compile-and-time runs.
if __name__ == "__main__":
    for order in ["ijk", "ikj", "jik", "jki", "kij", "kji"]:
        test(array_array, "mint[][]", order)
#!/bin/bash
# Build ./main, then time it five times and print the "user" CPU-time line of
# each run. The checksum printed by ./main is compared with the known-correct
# value (override with EXPECTED=...), so a broken template or loop order is
# reported on stderr instead of being timed silently.
EXPECTED=${EXPECTED:-6597111}
make main || exit 1
for run in 1 2 3 4 5; do
  # `time` is a bash keyword. Running it in a subshell lets 2>&1 capture its
  # report together with the program's own output.
  out=$( (time ./main) 2>&1 )
  result=$(printf '%s\n' "$out" | head -n 1)
  if [ "$result" != "$EXPECTED" ]; then
    echo "run $run: wrong checksum '$result' (expected $EXPECTED)" >&2
  fi
  # Select the line by its label rather than its position (NR==4), which
  # breaks as soon as ./main prints more or fewer lines.
  printf '%s\n' "$out" | awk '/^user/ {print $0}'
done

纯原生数组

代码

// Naive triple loop over built-in 2-D arrays, shown in i-j-k order; the
// harness substitutes the other five orders of i, j, k. The dimensions match
// the harness template that produced the "mint[][]" timings below.
mint a[1000][1000], b[1000][1000], c[1000][1000];
// Nothing to set up for fixed-size arrays; main() still calls init().
void init() {}
void mul() {
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < n; j++) {
      for (int k = 0; k < n; k++) {
        c[i][j] += a[i][k] * b[k][j];
      }
    }
  }
}

结果

mint[][] ijk
user    0m1.901s
user    0m1.881s
user    0m1.953s
user    0m1.924s
user    0m1.921s
mint[][] ikj
user    0m1.405s
user    0m1.418s
user    0m1.417s
user    0m1.420s
user    0m1.413s
mint[][] jik
user    0m1.851s
user    0m1.843s
user    0m1.838s
user    0m1.851s
user    0m1.809s
mint[][] jki
user    0m2.549s
user    0m2.796s
user    0m3.098s
user    0m2.451s
user    0m2.441s
mint[][] kij
user    0m1.448s
user    0m1.461s
user    0m1.432s
user    0m1.438s
user    0m1.441s
mint[][] kji
user    0m2.368s
user    0m2.301s
user    0m2.312s
user    0m2.343s
user    0m2.393s

结论

纯原生数组,\(i, k, j\) 和 \(k, i, j\) 的顺序跑得最快。

vector 数组

代码

// c += a * b in i-k-j order: for fixed (row, mid) the innermost loop walks
// row c[row] and row b[mid] by increasing column while a[row][mid] stays put.
void mul() {
  for (int row = 0; row < n; row++) c[row].resize(n);  // give each row of c n entries
  for (int row = 0; row < n; row++)
    for (int mid = 0; mid < n; mid++)
      for (int col = 0; col < n; col++)
        c[row][col] += a[row][mid] * b[mid][col];
}

结果

vector[] ijk
user    0m2.241s
user    0m2.186s
user    0m2.173s
user    0m2.232s
user    0m2.137s
vector[] ikj
user    0m1.314s
user    0m1.271s
user    0m1.273s
user    0m1.279s
user    0m1.272s
vector[] jik
user    0m2.130s
user    0m2.144s
user    0m2.122s
user    0m2.121s
user    0m2.110s
vector[] jki
user    0m3.463s
user    0m3.498s
user    0m4.141s
user    0m4.225s
user    0m3.928s
vector[] kij
user    0m1.305s
user    0m1.275s
user    0m1.264s
user    0m1.287s
user    0m1.256s
vector[] kji
user    0m3.164s
user    0m3.174s
user    0m3.231s
user    0m3.184s
user    0m3.677s

结论

vector 数组在 \(i, k, j\) 和 \(k, i, j\) 顺序下比纯数组略快(约 1.27s 对 1.42s),其余四种顺序反而更慢;\(i, k, j\) 和 \(k, i, j\) 的顺序仍然跑得最快。

valarray 数组

代码

// valarray rows: init() sizes every row to n, and mul() adds the scaled row
// a[i][k] * b[k] to c[i] as one valarray expression, so only the two loop
// indices i and k are permuted. Literal braces are doubled for str.format.
valarray<mint> a[1000], b[1000], c[1000];
void init() {{ for (int i = 0; i < n; i++) a[i].resize(n), b[i].resize(n), c[i].resize(n); }}
void mul() {{
for (int {0} = 0; {0} < n; {0}++) 
for (int {1} = 0; {1} < n; {1}++) 
    c[i] += a[i][k] * b[k];
}}

结果

valarray[] ik
user    0m1.325s
user    0m1.319s
user    0m1.322s
user    0m1.332s
user    0m1.332s
valarray[] ki
user    0m1.331s
user    0m1.310s
user    0m1.323s
user    0m1.344s
user    0m1.323s

vector 套 valarray

代码

// vector of valarray rows: init() creates n rows of n elements each, and
// mul() adds the scaled row a[i][k] * b[k] to c[i] as one valarray
// expression; only i and k are permuted. Literal braces are doubled for
// str.format.
vector<valarray<mint>> a, b, c;
void init() {{ a.resize(n, valarray<mint>(n)); c.resize(n, valarray<mint>(n)); b.resize(n, valarray<mint>(n)); }}
void mul() {{
for (int {0} = 0; {0} < n; {0}++) 
for (int {1} = 0; {1} < n; {1}++) 
    c[i] += a[i][k] * b[k];
}}

结果

vector<valarray> ik
user    0m1.261s
user    0m1.290s
user    0m1.273s
user    0m1.276s
user    0m1.270s
vector<valarray> ki
user    0m1.334s
user    0m1.321s
user    0m1.311s
user    0m1.310s
user    0m1.312s

vector 套 vector

代码

// Vector of vectors: each row is its own vector<mint>, created with n
// elements in init(); the template permutes all three loop indices.
vector<vector<mint>> a, b, c;
void init() {{ a.resize(n,vector<mint>(n)); c.resize(n,vector<mint>(n)); b.resize(n,vector<mint>(n)); }}
void mul() {{
for (int {0} = 0; {0} < n; {0}++) 
for (int {1} = 0; {1} < n; {1}++) 
for (int {2} = 0; {2} < n; {2}++)
    c[i][j] += a[i][k] * b[k][j];
}}

结果

vector<vector> ijk
user    0m2.218s
user    0m2.170s
user    0m2.273s
user    0m2.206s
user    0m2.226s
vector<vector> ikj
user    0m1.262s
user    0m1.230s
user    0m1.254s
user    0m1.245s
user    0m1.240s
vector<vector> jik
user    0m2.156s
user    0m2.143s
user    0m2.167s
user    0m2.158s
user    0m2.175s
vector<vector> jki
user    0m3.277s
user    0m3.171s
user    0m3.080s
user    0m3.225s
user    0m3.149s
vector<vector> kij
user    0m1.273s
user    0m1.263s
user    0m1.274s
user    0m1.287s
user    0m1.288s
vector<vector> kji
user    0m3.351s
user    0m3.021s
user    0m2.867s
user    0m2.879s
user    0m2.824s

原生数组展平(暴力实现)

代码

// Flattened built-in arrays: element (r, s) lives at index r * n + s.
// Literal braces are doubled for str.format.
// NOTE(review): the shared main() tail indexes a[i][j] and c[i][j].val(),
// which would not compile for these 1-D arrays; presumably a tail using
// a[i * n + j] was substituted for this run. Confirm before reusing.
mint a[1000000], b[1000000], c[1000000];
void init() {{ }}
void mul() {{
for (int {0} = 0; {0} < n; {0}++) 
for (int {1} = 0; {1} < n; {1}++) 
for (int {2} = 0; {2} < n; {2}++)
    c[i * n + j] += a[i * n + k] * b[k * n + j];
}}

结果

mint[i * n + j] ijk
user    0m3.086s
user    0m3.122s
user    0m3.082s
user    0m3.052s
user    0m3.258s
mint[i * n + j] ikj
user    0m2.122s
user    0m2.106s
user    0m2.112s
user    0m2.098s
user    0m2.092s
mint[i * n + j] jik
user    0m3.084s
user    0m3.128s
user    0m3.088s
user    0m3.122s
user    0m3.076s
mint[i * n + j] jki
user    0m3.519s
user    0m3.773s
user    0m3.571s
user    0m3.515s
user    0m3.487s
mint[i * n + j] kij
user    0m2.160s
user    0m2.130s
user    0m2.179s
user    0m2.188s
user    0m2.180s
mint[i * n + j] kji
user    0m3.174s
user    0m3.202s
user    0m3.217s
user    0m3.306s
user    0m3.330s

原生数组展平(针对性优化)

ijk

user    0m2.769s
user    0m2.688s
user    0m2.701s
user    0m2.678s
user    0m2.737s

ikj

user    0m1.281s
user    0m1.271s
user    0m1.256s
user    0m1.273s
user    0m1.271s

kij

user    0m1.869s
user    0m1.847s
user    0m1.875s
user    0m1.826s
user    0m1.788s

其它不测了

vector 展平

ikj

user    0m1.358s
user    0m1.377s
user    0m1.330s
user    0m1.323s
user    0m1.405s

valarray 展平(slice)

ikj

注意 slice_array 没有重载与标量相乘的运算符(只有 *= 等以 valarray 为参数的复合赋值),写起来很别扭。

user    0m1.420s
user    0m1.404s
user    0m1.486s
user    0m1.472s
user    0m1.406s

总结

\(i, k, j\) 顺序实至名归。在这一顺序下,vector<valarray<mint>>、vector<vector<mint>>、原生数组展平(但必须写成指针自增自减的形式)的表现都比较好。

不知道为什么能测出这样的结论,说好的 vector<vector<mint>> 的储存不连续呢?那写矩阵题到底应该将矩阵封装成什么?

作为经验丰富的想象学竞赛选手,读者自行想象不难。

posted @ 2024-07-26 20:19  caijianhong  阅读(10)  评论(0编辑  收藏  举报