Skip to content

Commit 6fd98a4

Browse files
npmillernormallytangent
authored andcommitted
[BLAS][HIP] Fix blas support for rocBLAS 4+ (#519)
1 parent dde8190 commit 6fd98a4

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/blas/backends/rocblas/rocblas_level3.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,17 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe
381381
auto a_ = sc.get_mem<rocDataType *>(a_acc);
382382
auto b_ = sc.get_mem<rocDataType *>(b_acc);
383383
rocblas_status err;
384+
#if ROCBLAS_VERSION_MAJOR >= 4
385+
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
386+
get_rocblas_fill_mode(upper_lower),
387+
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
388+
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb);
389+
#else
384390
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
385391
get_rocblas_fill_mode(upper_lower),
386392
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
387393
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb);
394+
#endif
388395
});
389396
});
390397
}
@@ -805,10 +812,17 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp
805812
auto a_ = reinterpret_cast<const rocDataType *>(a);
806813
auto b_ = reinterpret_cast<rocDataType *>(b);
807814
rocblas_status err;
815+
#if ROCBLAS_VERSION_MAJOR >= 4
816+
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
817+
get_rocblas_fill_mode(upper_lower),
818+
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
819+
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb);
820+
#else
808821
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
809822
get_rocblas_fill_mode(upper_lower),
810823
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
811824
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb);
825+
#endif
812826
});
813827
});
814828

0 commit comments

Comments
 (0)