QUDA Interfaces to cuBLAS
A powerful feature of the CUDA SDK is its set of strided, batched BLAS routines. These routines allow users to supply arrays of matrices that are processed in a batched fashion, leading to excellent GPU utilisation. The interface offers two wrappers, `blasLUInvQuda` and `blasGEMMQuda`, which accept host or device data and pass it to `cublas(C/Z)getrfBatched` and `cublas(H/S/C/D/Z)gemmStridedBatched` respectively. Common to both wrapper calls is a `QudaBLASParam` structure, defined in `include/quda.h`, that holds entries for all parameters required by either call.
```cpp
typedef struct QudaBLASParam_s {
  size_t struct_size; /**< Size of this struct in bytes. Used to ensure that the host application and QUDA see the same struct */

  QudaBLASType blas_type; /**< Type of BLAS computation to perform */

  // GEMM params
  QudaBLASOperation trans_a; /**< operation op(A) that is non- or (conj.) transpose. */
  QudaBLASOperation trans_b; /**< operation op(B) that is non- or (conj.) transpose. */
  int m;                /**< number of rows of matrix op(A) and C. */
  int n;                /**< number of columns of matrix op(B) and C. */
  int k;                /**< number of columns of op(A) and rows of op(B). */
  int lda;              /**< leading dimension of two-dimensional array used to store the matrix A. */
  int ldb;              /**< leading dimension of two-dimensional array used to store matrix B. */
  int ldc;              /**< leading dimension of two-dimensional array used to store matrix C. */
  int a_offset;         /**< position in the A array from which to begin read/write. */
  int b_offset;         /**< position in the B array from which to begin read/write. */
  int c_offset;         /**< position in the C array from which to begin read/write. */
  int a_stride;         /**< stride of the A array in strided(batched) mode */
  int b_stride;         /**< stride of the B array in strided(batched) mode */
  int c_stride;         /**< stride of the C array in strided(batched) mode */
  double_complex alpha; /**< scalar used for multiplication. */
  double_complex beta;  /**< scalar used for multiplication. If beta==0, C does not have to be a valid input. */

  // LU inversion params
  int inv_mat_size; /**< The rank of the square matrix in the LU inversion */

  // Common params
  int batch_count;              /**< number of pointers contained in arrayA, arrayB and arrayC. */
  QudaBLASDataType data_type;   /**< Specifies if using S(C) or D(Z) BLAS type */
  QudaBLASDataOrder data_order; /**< Specifies if using Row or Column major */
} QudaBLASParam;
```
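Before either call, the host application fills this structure. A minimal sketch of the common boilerplate is shown below, assuming the usual pattern of zero-initialising the struct and setting `struct_size` so that QUDA can verify the layout; the fields specific to the GEMM and LU paths are illustrated in the sections that follow.

```cpp
#include <cstring>
#include <quda.h>

// Minimal sketch: an empty QudaBLASParam with struct_size set, ready to be
// filled in for either the GEMM or the LU-inversion path.
QudaBLASParam make_empty_blas_param()
{
  QudaBLASParam param;
  std::memset(&param, 0, sizeof(param));
  param.struct_size = sizeof(QudaBLASParam);
  return param;
}
```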
The `blasLUInvQuda` wrapper routine accepts host or device pointers to arrays of matrices `Ainv` and `A`, in which the solution (`Ainv`) and problem (`A`) matrices are stored. It also takes a `QudaBoolean use_native` instructing the interface either to call the native BLAS library (cuBLAS, hipBLAS) or to divert to generic host code routines using Eigen.
```cpp
/**
 * @brief Strided Batched in-place matrix inversion via LU
 * @param[out] Ainv The array containing the A inverse matrix data
 * @param[in] A The array containing the A matrix data
 * @param[in] use_native Boolean to use either the native or generic version
 * @param[in] param The data defining the problem execution.
 */
void blasLUInvQuda(void *Ainv, void *A, QudaBoolean use_native, QudaBLASParam *param);
```
Data transfer is performed automatically by the wrapper: QUDA constructs the appropriate device arrays and parameters for the native BLAS call and returns the result to the host-side `Ainv` array.
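As an illustration, a host-side invocation might look like the following sketch; the enum value names (`QUDA_BLAS_LU_INV`, `QUDA_BLAS_DATATYPE_Z`, `QUDA_BLAS_DATAORDER_ROW`) are assumptions to be checked against `enum_quda.h`, and the data are taken to be row-major double-complex matrices.

```cpp
#include <complex>
#include <cstring>
#include <quda.h>

// Sketch: invert a batch of n x n double-complex matrices held contiguously on the host.
void lu_invert_batch_example(std::complex<double> *Ainv, std::complex<double> *A,
                             int n, int batch)
{
  QudaBLASParam param;
  std::memset(&param, 0, sizeof(param));
  param.struct_size = sizeof(QudaBLASParam);
  param.blas_type = QUDA_BLAS_LU_INV;          // assumed enum value, see enum_quda.h
  param.inv_mat_size = n;                      // rank of each square matrix
  param.batch_count = batch;                   // number of matrices in Ainv and A
  param.data_type = QUDA_BLAS_DATATYPE_Z;      // assumed enum value: double complex
  param.data_order = QUDA_BLAS_DATAORDER_ROW;  // assumed enum value: row-major

  // Host pointers are accepted; QUDA moves the data to the device and back.
  blasLUInvQuda((void *)Ainv, (void *)A, QUDA_BOOLEAN_TRUE, &param);
}
```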
We also allow finer-grained control over this wrapper via `BatchInvertMatrix`.
```cpp
/**
 * @brief Batch inversion of the matrix field using an LU decomposition method.
 * @param[out] Ainv Matrix field containing the inverse matrices
 * @param[in] A Matrix field containing the input matrices
 * @param[in] n Dimension of each matrix
 * @param[in] batch Problem batch size
 * @param[in] precision Precision of the input/output data
 * @param[in] location Location of the input/output data
 * @return Number of flops done in this computation
 */
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision precision,
                            QudaFieldLocation location);
```
Here, one may pass device or host pointers, indicating this via the `location` argument. Array lengths are deduced from the `batch` value.
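Since `BatchInvertMatrix` belongs to QUDA's internal C++ layer rather than the C interface in `quda.h`, the header and namespace used in the sketch below are assumptions to be checked against the source (e.g. `include/blas_lapack.h`).

```cpp
#include <blas_lapack.h> // assumed internal QUDA header declaring BatchInvertMatrix
#include <cstdint>

// Sketch: invert a batch of n x n matrices that already reside in device memory.
// Host-resident data would instead be flagged with QUDA_CPU_FIELD_LOCATION.
long long invert_on_device(void *d_Ainv, void *d_A, int n, uint64_t batch)
{
  // The enclosing namespace is an assumption; adjust to match the QUDA version in use.
  return quda::blas_lapack::native::BatchInvertMatrix(d_Ainv, d_A, n, batch,
                                                      QUDA_DOUBLE_PRECISION,
                                                      QUDA_CUDA_FIELD_LOCATION);
}
```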
The `blasGEMMQuda` wrapper routine accepts host or device pointers to arrays of matrices `A`, `B`, and `C`, where the output of the operation
C = \alpha * op(A) * op(B) + \beta * C
is stored in place in `C`. It also takes a `QudaBoolean use_native` instructing the interface either to call the native BLAS library (cuBLAS, hipBLAS) or to divert to generic host code routines using Eigen, and a `QudaBLASParam *param` defining the GEMM to be performed.
```cpp
/**
 * @brief Strided Batched GEMM
 * @param[in] arrayA The array containing the A matrix data
 * @param[in] arrayB The array containing the B matrix data
 * @param[in,out] arrayC The array containing the C matrix data
 * @param[in] native Boolean to use either the native or generic version
 * @param[in] param The data defining the problem execution.
 */
void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean native, QudaBLASParam *param);
```
An important deviation from the standard cuBLAS GEMM call arguments is the form of the strides,

```cpp
int a_stride; /**< stride of the A array in strided(batched) mode */
int b_stride; /**< stride of the B array in strided(batched) mode */
int c_stride; /**< stride of the C array in strided(batched) mode */
```

in that, rather than being the number of data elements in a matrix, the strides are defined as the number of matrices to stride over. This is because the QUDA GEMM wrappers accept data in row order and then infer the proper transposition/conjugation operations, leading dimensions, strides, etc., so that the row-order data may be passed to the column-order compute of the native architecture with as few argument adjustments as possible. Another important point is that `batch_count` is to be understood as the number of matrices in the passed arrays. This is done so that array lengths can be inferred for the device/host transfers. In the wrapper, the actual number of GEMM operations is inferred from the matrix stride values.
For example, if one has

```cpp
a_stride = 1;
b_stride = 1;
c_stride = 1;
batch_count = 128;
```
then 128 batched matrix GEMMs are computed from the arrays `A`, `B`, and `C`, each of length 128 matrices. On the other hand, if one passes
```cpp
a_stride = 1;
b_stride = 2;
c_stride = 1;
batch_count = 128;
```
then 128/2 = 64 GEMMs are computed using

```
A_{0} * B_{0} = C_{0},
A_{1} * B_{2} = C_{1},
A_{2} * B_{4} = C_{2},
... ,
A_{63} * B_{126} = C_{63}
```
Further, one can set a stride to 0 (`a_stride = 0` here) and compute

```
A_{0} * B_{0} = C_{0},
A_{0} * B_{1} = C_{1},
A_{0} * B_{2} = C_{2},
... ,
A_{0} * B_{127} = C_{127}
```
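Putting the pieces together, a host-side call in the simple `stride = 1` case might look like the following sketch; the enum value names are assumptions to be checked against `enum_quda.h`, and the data are taken to be row-major double-complex matrices.

```cpp
#include <complex>
#include <cstring>
#include <quda.h>

// Sketch: compute C_i <- alpha * A_i * B_i + beta * C_i for a batch of 128
// row-major double-complex matrices (A_i is m x k, B_i is k x n, C_i is m x n).
void gemm_batch_example(std::complex<double> *A, std::complex<double> *B,
                        std::complex<double> *C, int m, int n, int k)
{
  QudaBLASParam param;
  std::memset(&param, 0, sizeof(param));
  param.struct_size = sizeof(QudaBLASParam);
  param.blas_type = QUDA_BLAS_GEMM;            // assumed enum value, see enum_quda.h
  param.trans_a = QUDA_BLAS_OP_N;              // assumed enum value: no transposition
  param.trans_b = QUDA_BLAS_OP_N;
  param.m = m; param.n = n; param.k = k;
  param.lda = k; param.ldb = n; param.ldc = n; // leading dimensions of the row-major data
  param.a_offset = param.b_offset = param.c_offset = 0;
  param.a_stride = param.b_stride = param.c_stride = 1; // strides in units of matrices
  param.alpha = 1.0; param.beta = 0.0;
  param.batch_count = 128;                     // number of matrices in each array
  param.data_type = QUDA_BLAS_DATATYPE_Z;      // assumed enum value: double complex
  param.data_order = QUDA_BLAS_DATAORDER_ROW;  // assumed enum value: row-major

  // Host pointers are accepted; QUDA_BOOLEAN_TRUE selects the native cuBLAS/hipBLAS path.
  blasGEMMQuda((void *)A, (void *)B, (void *)C, QUDA_BOOLEAN_TRUE, &param);
}
```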
Similar to `BatchInvertMatrix`, we offer a finer-grained wrapper to the native GEMM:
```cpp
/**
 * @brief Strided Batched GEMM. This function performs N GEMM type operations in a
 * strided batched fashion. If the user passes
 *   stride<A,B,C> = -1
 * it deduces the strides for the A, B, and C arrays from the matrix dimensions,
 * leading dimensions, etc., and will behave identically to the batched GEMM.
 * If any of the stride<A,B,C> values passed in the parameter structure are
 * greater than or equal to 0, the routine accepts the user's values instead.
 *
 * Example: If the user passes
 *   a_stride = 0
 * the routine will use only the first matrix in the A array and compute
 *   C_{n} <- a * A_{0} * B_{n} + b * C_{n}
 * where n is the batch index.
 *
 * @param[in] A Matrix field containing the A input matrices
 * @param[in] B Matrix field containing the B input matrices
 * @param[in,out] C Matrix field containing the result, and matrix to be added
 * @param[in] blas_param Parameter structure defining the GEMM type
 * @param[in] location Location of the input/output data
 * @return Number of flops done in this computation
 */
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location);
```
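As with `BatchInvertMatrix`, this routine lives in QUDA's internal C++ layer rather than the C interface, so the header and namespace in the sketch below are assumptions to be checked against the source (e.g. `include/blas_lapack.h`). The sketch shows the `stride = -1` convention described above, in which the strides are deduced from the matrix dimensions.

```cpp
#include <blas_lapack.h> // assumed internal QUDA header declaring stridedBatchGEMM

// Sketch: run a strided-batched GEMM directly on device-resident data, letting
// the routine deduce the strides from the matrix dimensions via stride = -1.
long long gemm_on_device(void *d_A, void *d_B, void *d_C, QudaBLASParam param)
{
  param.a_stride = -1; // deduce strides from m, n, k and the leading dimensions
  param.b_stride = -1;
  param.c_stride = -1;

  // The enclosing namespace is an assumption; adjust to match the QUDA version in use.
  return quda::blas_lapack::native::stridedBatchGEMM(d_A, d_B, d_C, param,
                                                     QUDA_CUDA_FIELD_LOCATION);
}
```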