cuda - Does cuBLAS support mixed precision matrix multiplication in the form C[f32] = A[bf16] * B[f32]? - Stack Overflow

I'm working on mixed precision in deep-learning LLMs. The intermediates are mostly F32, while the weights can be any other type such as BF16 or F16, or even a quantized type like Q8_0 or Q4_0. It would be very useful if cuBLAS supported inputs with different data types.

I learned about the cuBLAS mixed-precision API from the official documentation, and I wonder whether cublasGemmEx() supports mixed data types as inputs, as in C[f32] = A[f16] * B[f32]. I wrote a simple program to test it, and it failed with the following error:

yliu@servere5:~$ ./mix
Device: NVIDIA GeForce RTX 4060 Ti
Compute capability: 8.9

cuBLAS error at m.cu 81: 15

The code below was compiled with nvcc -o mix m.cu -lcublas -arch=native:

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <stdio.h>

#define CUDA_CHECK(call) \
do { \
    cudaError_t err = call; \
    if (err != cudaSuccess) { \
        printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
        exit(EXIT_FAILURE); \
    } \
} while(0)

#define CUBLAS_CHECK(call) \
do { \
    cublasStatus_t status = call; \
    if (status != CUBLAS_STATUS_SUCCESS) { \
        printf("cuBLAS error at %s %d: %d\n", __FILE__, __LINE__, status); \
        exit(EXIT_FAILURE); \
    } \
} while(0)

void printMatrix(const char* name, float* matrix, int rows, int cols) {
    printf("%s:\n", name);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            printf("%.2f ", matrix[i * cols + j]);
        }
        printf("\n");
    }
    printf("\n");
}

int main() {
    int deviceId;
    cudaDeviceProp deviceProp;

    CUDA_CHECK(cudaGetDevice(&deviceId));
    CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, deviceId));

    printf("Device: %s\n", deviceProp.name);
    printf("Compute capability: %d.%d\n\n", deviceProp.major, deviceProp.minor);

    const int m = 4;
    const int n = 4;
    const int k = 4;

    half* h_A = new half[m * k];
    float* h_B = new float[k * n];
    float* h_C = new float[m * n];

    for (int i = 0; i < m * k; i++) {
        h_A[i] = __float2half(static_cast<float>(i % 10));
    }

    for (int i = 0; i < k * n; i++) {
        h_B[i] = static_cast<float>(i % 10);
    }

    half* d_A;
    float* d_B;
    float* d_C;

    CUDA_CHECK(cudaMalloc(&d_A, m * k * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&d_B, k * n * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&d_C, m * n * sizeof(float)));

    CUDA_CHECK(cudaMemcpy(d_A, h_A, m * k * sizeof(half), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_B, h_B, k * n * sizeof(float), cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    CUBLAS_CHECK(cublasCreate(&handle));


    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Perform mixed precision matrix multiplication: C = alpha * A * B + beta * C
    // Using cublasGemmEx for mixed precision computation
    CUBLAS_CHECK(cublasGemmEx(
        handle,
        CUBLAS_OP_N,                // op_A: no transpose
        CUBLAS_OP_N,                // op_B: no transpose
        m, n, k,                    // m, n, k
        &alpha,                     // alpha
        d_A, CUDA_R_16F, m,         // A, A type, lda
        d_B, CUDA_R_32F, k,         // B, B type, ldb
        &beta,                      // beta
        d_C, CUDA_R_32F, m,         // C, C type, ldc
        CUDA_R_32F,                 // Compute type (FP32)
        CUBLAS_GEMM_DEFAULT         // Algorithm to use
    ));

    CUDA_CHECK(cudaMemcpy(h_C, d_C, m * n * sizeof(float), cudaMemcpyDeviceToHost));

    float* h_A_float = new float[m * k];
    for (int i = 0; i < m * k; i++) {
        h_A_float[i] = __half2float(h_A[i]);
    }

    printMatrix("Matrix A (FP16 converted to FP32 for display)", h_A_float, m, k);
    printMatrix("Matrix B (FP32)", h_B, k, n);
    printMatrix("Matrix C = A * B (FP32)", h_C, m, n);

    delete[] h_A;
    delete[] h_B;
    delete[] h_C;
    delete[] h_A_float;

    CUDA_CHECK(cudaFree(d_A));
    CUDA_CHECK(cudaFree(d_B));
    CUDA_CHECK(cudaFree(d_C));
    CUBLAS_CHECK(cublasDestroy(handle));

    return 0;
}

I tried the form C[f32] = A[f16] * B[f16] and it worked without error. My question is: does the mixed-precision API only support inputs of the same data type, or am I missing something?
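
For reference, a minimal sketch of the variant that did work for me, with both inputs as FP16 (d_B_half is a name I'm introducing here, not part of the program above):

// Both inputs FP16, output and compute type FP32.
// d_B_half is a hypothetical half* holding B in FP16.
CUBLAS_CHECK(cublasGemmEx(
    handle,
    CUBLAS_OP_N, CUBLAS_OP_N,
    m, n, k,
    &alpha,
    d_A, CUDA_R_16F, m,        // A, FP16
    d_B_half, CUDA_R_16F, k,   // B, FP16 -- matches Atype
    &beta,
    d_C, CUDA_R_32F, m,        // C, FP32
    CUDA_R_32F,                // compute type FP32
    CUBLAS_GEMM_DEFAULT
));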

Thanks in advance.


As Robert Crovella said: "the Atype and Btype have a single column; they are expected to always match."

In the supported-data-types table for cublasGemmEx in the cuBLAS documentation, Atype and Btype share a single column, so the two inputs must have the same type. Error 15 is CUBLAS_STATUS_NOT_SUPPORTED, which is what an unlisted combination like A[f16]/B[f32] returns.

This is consistent with what my simple program shows.
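
So if you need C[f32] = A[f16] * B[f32], the operand types have to be made equal before the call. One possibility is to up-convert the FP16 operand to FP32 on the device and run a plain FP32 GEMM. A minimal sketch (the kernel convertHalfToFloat and the buffer d_A_f32 are mine, not from the program above):

// Up-convert each FP16 element of A to FP32 (hypothetical helper).
__global__ void convertHalfToFloat(const half* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = __half2float(in[i]);
}

// ... allocate float* d_A_f32 with m * k elements, then:
convertHalfToFloat<<<(m * k + 255) / 256, 256>>>(d_A, d_A_f32, m * k);
CUDA_CHECK(cudaGetLastError());

// Now both inputs are FP32, so an ordinary SGEMM works.
CUBLAS_CHECK(cublasSgemm(
    handle, CUBLAS_OP_N, CUBLAS_OP_N,
    m, n, k,
    &alpha,
    d_A_f32, m,
    d_B, k,
    &beta,
    d_C, m));

This costs an extra conversion pass and a temporary buffer, but it keeps the GEMM itself within a supported type combination.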
