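Notes on implementing triangular_solve with cuBLAS
Introduction to the triangular_solve operator
Goal: solve a linear system whose coefficient matrix is a square upper or lower triangular matrix, i.e. find the unique solution X of \(AX=B\) or \(XA=B\).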
Input parameters:
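- A: shape [*, M, M]; A is a square matrix. Here * stands for the batch dimension, which may also be absent.
- B: shape [*, M, K].
- left_side: whether the coefficient matrix A sits on the left or the right of the unknown X. For example, left_side = true solves \(AX=B\), and left_side = false solves \(XA=B\).
- upper: whether the coefficient matrix is upper or lower triangular, i.e. whether its entries lie above or below the main diagonal (the diagonal running from the top-left to the bottom-right corner).
- transpose_a: whether to transpose the coefficient matrix.
- unit_diagonal: whether all diagonal entries of the coefficient matrix are 1.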
Output:
- X: if AX = B, i.e. left_side = True, the output shape is [*, M, K]; if XA = B, i.e. left_side = False, the output shape is [*, K, M], in which case B has that same shape.
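Problems encountered during implementation
The solve is done with the cublasStrsm or cublasStrsmBatched routines from cuBLAS. These routines expect column-major data, while the data passed in from the upper layer is row-major. Transposing both sides of \(AX=B\) resolves the mismatch and avoids converting the data by hand before calling cublasStrsmBatched. The same trick applies to other cuBLAS routines: for matrix multiplication, \(AB=C\), transposing both sides likewise avoids any data conversion before the call.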
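For the matrix-multiplication case, the transposed identity is \(C^T = B^T A^T\). The sketch below illustrates it on the CPU. `gemm_col_major` is a hypothetical stand-in for a column-major GEMM routine (it is not a cuBLAS function), and `matmul_row_major` is an illustrative helper that feeds it row-major buffers with the operands swapped.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a column-major GEMM routine (not a cuBLAS API):
// computes C = A * B, where A is m x k, B is k x n, C is m x n, all stored
// column-major with leading dimensions lda, ldb, ldc.
void gemm_col_major(int m, int n, int k,
                    const double* A, int lda,
                    const double* B, int ldb,
                    double* C, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      double acc = 0.0;
      for (int p = 0; p < k; ++p) {
        acc += A[i + p * lda] * B[p + j * ldb];
      }
      C[i + j * ldc] = acc;
    }
  }
}

// Row-major C (m x n) = A (m x k) * B (k x n), using only the column-major
// routine. The routine reads the row-major buffers as A^T and B^T, so we ask
// it for C^T = B^T * A^T; the column-major storage of C^T is exactly C in
// row-major order, so no buffer is ever converted.
std::vector<double> matmul_row_major(int m, int n, int k,
                                     const std::vector<double>& A,
                                     const std::vector<double>& B) {
  assert(A.size() == static_cast<std::size_t>(m) * k);
  assert(B.size() == static_cast<std::size_t>(k) * n);
  std::vector<double> C(static_cast<std::size_t>(m) * n, 0.0);
  // Operands swapped: (n x k) * (k x m) = (n x m) in the column-major view.
  gemm_col_major(n, m, k, B.data(), n, A.data(), k, C.data(), n);
  return C;
}
```

Calling `matmul_row_major(2, 2, 2, {1, 2, 3, 4}, {5, 6, 7, 8})` returns the row-major product `{19, 22, 43, 50}`.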
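Taking the triangular_solve operator as an example, the procedure is as follows. First, transpose both sides of the equation:
\[ (AX)^T = B^T \;\Longrightarrow\; X^T A^T = B^T \]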
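As a result, the matrices passed in by the user need no conversion: a matrix stored row-major by the upper layer, when read column-major by cuBLAS, is exactly what the interface needs. For example, suppose the matrix \(A\) passed in by the upper layer is stored in memory as [1, 1, 1, 0, 2, 1, 0, 0, -1]. The same memory has a different mathematical meaning depending on the viewpoint: to the upper-layer user it is the matrix \(A\), but to cuBLAS it is \(A^T\). So, to use the cuBLAS interface, work with the transposed equation \(X^T A^T = B^T\). The same reasoning applies to the matrix \(B\). As for the result \(X^T\) that cuBLAS writes in column-major order: with the memory layout left unchanged, the upper-layer user reads it as \(X\) in row-major order, so the output needs no transpose either.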
# Matrix A in the row-major view
[[1, 1, 1],
[0, 2, 1],
[0, 0, -1]]
# The same memory in the column-major view, which is mathematically A^T
[[1, 0, 0],
[1, 2, 0],
[1, 1, -1]]
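The two views above can be reproduced with a few lines of host code. This is a minimal sketch; `view_row_major` and `view_col_major` are illustrative helper names, not part of cuBLAS.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Interpret a flat buffer as an n x n matrix stored row-major:
// element (i, j) lives at offset i * n + j.
Matrix view_row_major(const std::vector<double>& buf, int n) {
  assert(buf.size() == static_cast<std::size_t>(n) * n);
  Matrix m(n, std::vector<double>(n));
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < n; ++j)
      m[i][j] = buf[i * n + j];
  return m;
}

// Interpret the same flat buffer as an n x n matrix stored column-major:
// element (i, j) lives at offset i + j * n.
Matrix view_col_major(const std::vector<double>& buf, int n) {
  assert(buf.size() == static_cast<std::size_t>(n) * n);
  Matrix m(n, std::vector<double>(n));
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < n; ++j)
      m[i][j] = buf[i + j * n];
  return m;
}
```

For the buffer `{1, 1, 1, 0, 2, 1, 0, 0, -1}`, `view_row_major` yields the upper triangular \(A\) and `view_col_major` yields the lower triangular \(A^T\), without touching the memory.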
With this approach, the parameters passed to the cuBLAS interface need some care.
- left_side: must be negated.
- upper: must be negated.
- transpose_a: unchanged.
- unit_diagonal: unchanged.
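The mapping in the list above can be written down as a small helper. This is a sketch under stated assumptions: the enums are stand-ins for `cublasSideMode_t`, `cublasFillMode_t`, `cublasOperation_t` and `cublasDiagType_t`, and `TrsmArgs` and `to_col_major_args` are hypothetical names. The size fields follow from the same reasoning: a row-major B with `b_rows` rows and `b_cols` columns is seen by a column-major routine as a `b_cols` x `b_rows` matrix with leading dimension `b_cols`.

```cpp
#include <cassert>

// Stand-ins for the cuBLAS enums (assumption: used here only so the sketch
// compiles without the CUDA toolkit).
enum class Side { Left, Right };
enum class Fill { Lower, Upper };
enum class Op { NoTrans, Trans };
enum class Diag { NonUnit, Unit };

// Hypothetical bundle of the arguments a column-major trsm call would take.
struct TrsmArgs {
  Side side;
  Fill uplo;
  Op trans;
  Diag diag;
  int m;    // rows of B as seen in column-major order
  int n;    // columns of B as seen in column-major order
  int lda;  // leading dimension of A (A is square)
  int ldb;  // leading dimension of B
};

// Map the row-major, user-facing flags to column-major arguments.
TrsmArgs to_col_major_args(bool left_side, bool upper, bool transpose_a,
                           bool unit_diagonal, int b_rows, int b_cols) {
  assert(b_rows > 0 && b_cols > 0);
  TrsmArgs args;
  args.side = left_side ? Side::Right : Side::Left;      // negated
  args.uplo = upper ? Fill::Lower : Fill::Upper;          // negated
  args.trans = transpose_a ? Op::Trans : Op::NoTrans;     // unchanged
  args.diag = unit_diagonal ? Diag::Unit : Diag::NonUnit; // unchanged
  args.m = b_cols;  // B^T has b_cols rows ...
  args.n = b_rows;  // ... and b_rows columns
  args.lda = left_side ? b_rows : b_cols;  // order of the square matrix A
  args.ldb = b_cols;
  return args;
}
```

For a user call with left_side = true, upper = true and a 3 x 2 row-major B, this gives side = Right, uplo = Lower, m = 2, n = 3, lda = 3, ldb = 2.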
The reason upper must be negated is as follows. Suppose we have the following system of linear equations.
1 * x_1 + 1 * x_2 + 1 * x_3 = 0
2 * x_2 + 1 * x_3 = -9
-1 * x_3 = 5
Then the matrices \(A\), \(X\) and \(B\) are as follows, with the upper triangle of \(A\) used in the computation:
# A: 3*3, X: 3*1, B: 3*1
[[1, 1,  1],     [[x_1],     [[ 0],
 [0, 2,  1],  *   [x_2],  =   [-9],
 [0, 0, -1]]      [x_3]]      [ 5]]
Transposing both sides gives:
                       [[1, 0,  0],
[[x_1, x_2, x_3]]  *    [1, 2,  0],  =  [[0, -9, 5]]
                        [1, 1, -1]]
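Looking at the transposed system, the coefficient matrix that cuBLAS actually sees, \(A^T\), is lower triangular. So although the user asked for the upper triangle of \(A\), the lower-triangle fill mode must be passed to cuBLAS.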
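To check a result for this example independently of cuBLAS, a plain CPU back substitution on the row-major data is enough. The sketch below is a reference under stated assumptions, not the GPU path: `solve_upper_row_major` is a hypothetical helper that handles only the upper triangular, non-unit-diagonal, left-side case, with A of shape M x M and B of shape M x K, both row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Solve A X = B by back substitution, where A (M x M) is upper triangular
// with a nonzero diagonal and B (M x K) holds K right-hand sides.
// All matrices are row-major; the result X has the same shape as B.
std::vector<double> solve_upper_row_major(int M, int K,
                                          const std::vector<double>& A,
                                          const std::vector<double>& B) {
  assert(A.size() == static_cast<std::size_t>(M) * M);
  assert(B.size() == static_cast<std::size_t>(M) * K);
  std::vector<double> X(B);
  // Work from the last row upwards: row i depends only on rows below it.
  for (int i = M - 1; i >= 0; --i) {
    for (int c = 0; c < K; ++c) {
      double acc = X[i * K + c];
      for (int j = i + 1; j < M; ++j) {
        acc -= A[i * M + j] * X[j * K + c];
      }
      X[i * K + c] = acc / A[i * M + i];
    }
  }
  return X;
}
```

For the system above, with the right-hand side (0, -9, 5), this yields x_1 = 7, x_2 = -2, x_3 = -5, which satisfies all three equations.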
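Code
The code is adapted from the cuBLAS samples, with some modifications. It includes cublas_utils.h, which is expected to come from those samples and to provide CUBLAS_CHECK, CUDA_CHECK and print_vector, so that header must be on the include path. Compile and run with: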
nvcc -o triangular_solve triangular_solve.cu -lcublas && ./triangular_solve
/*
* Copyright 2020 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "cublas_utils.h"
using data_type = double;
int main(int argc, char *argv[]) {
cublasHandle_t cublasH = NULL;
cudaStream_t stream = NULL;
cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
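/* A below is upper triangular in row-major order; cuBLAS reads it column-major as a lower triangular matrix */
cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;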
cublasOperation_t transa = CUBLAS_OP_N;
cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;
const int m = 3;
const int n = 2;
const int lda = m;
const int ldb = n;
const int batch_size = 2;
const std::vector<data_type> alpha(n * batch_size, data_type(1));
const std::vector<data_type> A = {1.0, 1.0, 1.0,
0.0, 2.0, 1.0,
0.0, 0.0, -1.0,
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
0.0, 0.0, 1.0};
std::vector<data_type> B = {0.0, 0.0,
-9.0, -9.0,
5.0, 5.0,
1.0, 2.0,
3.0, 4.0,
5.0, 6.0};
std::vector<data_type> X(B.size(), data_type(0));
std::vector<data_type *> Aarray(batch_size, nullptr);
std::vector<data_type *> Barray(batch_size, nullptr);
data_type **d_Aarray = nullptr;
data_type **d_Barray = nullptr;
/* step 1: create cublas handle, bind a stream */
CUBLAS_CHECK(cublasCreate(&cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUBLAS_CHECK(cublasSetStream(cublasH, stream));
/* step 2: copy data to device */
for (int i = 0; i < batch_size; ++i) {
CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&Aarray[i]), sizeof(data_type) * m * m));
CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&Barray[i]), sizeof(data_type) * m * n));
}
CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_Aarray), sizeof(data_type *) * Aarray.size()));
CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_Barray), sizeof(data_type *) * Barray.size()));
for (int i = 0; i < batch_size; ++i) {
CUDA_CHECK(cudaMemcpyAsync(Aarray[i], A.data() + i * m * m,
sizeof(data_type) * m * m, cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(Barray[i], B.data() + i * m * n,
sizeof(data_type) * m * n, cudaMemcpyHostToDevice, stream));
}
/* the pointer arrays hold data_type *, so size them by pointer, not by element */
CUDA_CHECK(cudaMemcpyAsync(d_Aarray, Aarray.data(), sizeof(data_type *) * Aarray.size(),
cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(d_Barray, Barray.data(), sizeof(data_type *) * Barray.size(),
cudaMemcpyHostToDevice, stream));
/* step 3: compute */
CUBLAS_CHECK(cublasDtrsmBatched(cublasH, side, uplo, transa, diag, n, m, alpha.data(), d_Aarray,
lda, d_Barray, ldb, batch_size));
/* step 4: copy data to host */
for (int i = 0; i < batch_size; ++i) {
CUDA_CHECK(cudaMemcpyAsync(X.data() + i * m * n, Barray[i], sizeof(data_type) * m * n,
cudaMemcpyDeviceToHost, stream));
}
CUDA_CHECK(cudaStreamSynchronize(stream));
/* step 5: print the result */
for (int i = 0; i < batch_size; ++i) {
printf("batch = %d\n", i);
print_vector(m * n, X.data() + i * m * n);
printf("==========\n");
}
/* free resources */
for (int i = 0; i < batch_size; ++i) {
CUDA_CHECK(cudaFree(Aarray[i]));
CUDA_CHECK(cudaFree(Barray[i]));
}
CUDA_CHECK(cudaFree(d_Aarray));
CUDA_CHECK(cudaFree(d_Barray));
CUBLAS_CHECK(cublasDestroy(cublasH));
CUDA_CHECK(cudaStreamDestroy(stream));
CUDA_CHECK(cudaDeviceReset());
return EXIT_SUCCESS;
}