cuBLAS 实现 triangular_solve 记录

triangular_solve 算子简介

目标:求解上三角或者下三角作为系数方阵的线性系统,即求解 \(AX=B\) 或者 \(XA=B\),求解具有唯一解的 X。

接口输出参数:

  • A:形状 [*, M, M],A 是一个方阵。其中 * 表示 batch_size,或者直接没有也行。
  • B:形状 [*, M, K]。
  • left_side:指定系数方阵 A 在求解矩阵 X 的左侧还是右侧,比如 left_side = true,那么求解 \(AX=B\)
  • upper:指定系数方阵是上三角还是下三角。(主对角线的上方还是下方,主对角线是矩阵的左上角到右下角)
  • transpose_a:指定是否对系数方阵进行转置。
  • unit_diagonal:指定系数方阵对角线上的元素是否都是 1。

输出:

  • X:如果 XA = B,即 left_side = True,那么输出 shape 为 [*, M, K]; 如果 AX = B,即 left_side = False,那么输出 shape 为 [*, K, M]。

实现遇到的问题

使用 cuBLAS 库提供的 cublasStrsm 或者 cublasStrsmBatched 接口进行求解,不过 cuBLAS 的接口要求传入的数据是列主序(col major),而上层传入进来的数据是行主序(row major)的,为了解决这个问题,可以对 \(AX=B\) 的两边同时应用转置操作,可以避免在传入 cublasStrsmBatched 接口前,需要手动对数据进行转换。cuBLAS 的其他接口也可以参考这个做法,比如矩阵乘法,公式为 \(AB=C\),同样可以使用转置,避免调用前的数据转换操作。

以 triangular_solve 算子为例子,具体做法如下,首先对公式两边转置。

\[AX=B \]

\[X^T A^T = B^T \]

因此,用户传进来的矩阵可以避免数据转换,因为上层传进来的行主序存储的矩阵,在 cuBLAS 的列主序看来恰好是接口需要的。举个例子,上层传入的矩阵 \(A\) 在内存中为 [1, 1, 1, 0, 2, 1, 0, 0, -1],在不同视角看这段内存,数学上的含义不同。这段内存,对于上层的用户来说,对应的是矩阵 \(A\),而对于 cuBLAS 来说却是 \(A^T\)。因此,为了将 cuBLAS 的接口使用起来,可以使用转置后的公式 \(X^T A^T = B^T\)。对于矩阵 \(B\) 是同样的道理,至于输出的结果矩阵 \(X^T\),保持内存布局不变的话,从上层用户看来就是经过了转置的矩阵 \(X\),于是这种方法还不需要对输出的矩阵进行转置操作。

# 行主序视角的矩阵 A
[[1, 1, 1],
 [0, 2, 1],
 [0, 0, -1]]

# 列主序视角的矩阵,在数学上为 A^T
[[1, 0, 0],
 [1, 2, 0],
 [1, 1, -1]]

使用了上述方法之后,传入接口的参数需要稍微注意。

  • left_side:需要对参数取反。
  • upper:需要对参数取反。
  • transpose_a:保持不变。
  • unit_diagonal:保持不变。

upper 取反的原因如下,假设我们有如下线性方程组。

1 * x_1 + 1 * x_2 + 1 * x_3 = 0
          2 * x_2 + 1 * x_3 = -9
                   -1 * x_3 = 5

那么矩阵 \(A\)\(B\)\(X\) 分别如下,假设矩阵 \(A\)上三角进行计算:

# A: 3*3, X: 3*1, B: 3*1
[[1, 1, 1],    [[x_1],     [[0],
 [0, 2, 1],  *  [x_2],   =  [-9],
 [0, 0, -1]]    [x_3]]      [5]]

对两边进行了转置操作之后,得到如下:

                    [[1, 0, 0],
[[x_1, x_2, x_3]] *  [1, 2, 0],   = [[0, -9, 5]]
                     [1, 1, -1]]

观察如上转置后的矩阵,可以发现,我们需要的是矩阵 \(A\)下三角部分。

代码

代码来自 cuBLAS 的示例,并对其做了一定的修改。使用如下命令进行编译和运行:

nvcc -o triangular_solve triangular_solve.cu -lcublas && ./triangular_solve 
/*
 * Copyright 2020 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee.  Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users.  These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item.  Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "cublas_utils.h"

using data_type = double;

int main(int argc, char *argv[]) {
    double *darr = new double[5]{1.0, 2.0, 3.0, 4.0, 5.0};

    cublasHandle_t cublasH = NULL;
    cudaStream_t stream = NULL;
    cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
    cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
    cublasOperation_t transa = CUBLAS_OP_N;
    cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;

    const int m = 3;
    const int n = 2;
    const int lda = m;
    const int ldb = n;
    const int batch_size = 2;
    const std::vector<data_type> alpha(n * batch_size, data_type(1));
    const std::vector<data_type> A = {1.0, 1.0, 1.0, 
                                      0.0, 2.0, 1.0, 
                                      0.0, 0.0, -1.0,
                                      1.0, 0.0, 0.0,
                                      0.0, 1.0, 0.0,
                                      0.0, 0.0, 1.0};
    std::vector<data_type> B = {0.0, 0.0,
                                -9.0, -9.0, 
                                5.0, 5.0,
                                1.0, 2.0,
                                3.0, 4.0,
                                5.0, 6.0};
    std::vector<data_type> X(B.size(), data_type(0));
    std::vector<data_type *> Aarray(batch_size, nullptr);
    std::vector<data_type *> Barray(batch_size, nullptr);
    data_type **d_Aarray = nullptr;
    data_type **d_Barray = nullptr;

    /* step 1: create cublas handle, bind a stream */
    CUBLAS_CHECK(cublasCreate(&cublasH));
    CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
    CUBLAS_CHECK(cublasSetStream(cublasH, stream));

    /* step 2: copy data to device */
    for (int i = 0; i < batch_size; ++i) {
        CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&Aarray[i]), sizeof(data_type) * m * m));
        CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&Barray[i]), sizeof(data_type) * m * n));
    }
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_Aarray), sizeof(data_type *) * Aarray.size()));
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_Barray), sizeof(data_type *) * Barray.size()));

    for (int i = 0; i < batch_size; ++i) {
        CUDA_CHECK(cudaMemcpyAsync(Aarray[i], A.data() + i * m * m, 
                                   sizeof(data_type) * m * m,  cudaMemcpyHostToDevice, stream));
        CUDA_CHECK(cudaMemcpyAsync(Barray[i], B.data() + i * m * n,
                                   sizeof(data_type) * m * n, cudaMemcpyHostToDevice, stream));
    }
    CUDA_CHECK(cudaMemcpyAsync(d_Aarray, Aarray.data(), sizeof(data_type) * Aarray.size(),
                               cudaMemcpyHostToDevice, stream));
    CUDA_CHECK(cudaMemcpyAsync(d_Barray, Barray.data(), sizeof(data_type) * Barray.size(),
                               cudaMemcpyHostToDevice, stream));

    /* step 3: compute */
    CUBLAS_CHECK(cublasDtrsmBatched(cublasH, side, uplo, transa, diag, n, m, alpha.data(), d_Aarray, 
                                    lda, d_Barray, ldb, batch_size));

    /* step 4: copy data to host */
    for (int i = 0; i < batch_size; ++i) {
        CUDA_CHECK(cudaMemcpyAsync(X.data() + i * m * n, Barray[i], sizeof(data_type) * m * n, 
                                   cudaMemcpyDeviceToHost, stream));
    }

    CUDA_CHECK(cudaStreamSynchronize(stream));

    /* step 5: pring the result */
    for (int i = 0; i < batch_size; ++i) {
        printf("batch = %d\n", i);
        print_vector(m * n, X.data() + i * m * n);
        printf("==========\n");
    }

    /* free resources */
    for (int i = 0; i < batch_size; ++i) {
        CUDA_CHECK(cudaFree(Aarray[i]));
        CUDA_CHECK(cudaFree(Barray[i]));    
    }
    CUDA_CHECK(cudaFree(d_Aarray));
    CUDA_CHECK(cudaFree(d_Barray));

    CUBLAS_CHECK(cublasDestroy(cublasH));

    CUDA_CHECK(cudaStreamDestroy(stream));

    CUDA_CHECK(cudaDeviceReset());

    return EXIT_SUCCESS;
}
posted @ 2023-02-26 17:37  楷哥  阅读(229)  评论(0编辑  收藏  举报