Skip to content

Commit 97c4c10

Browse files
Fix for triangular_solve BEF executable unit test failure on ROCm.
1 parent aae8198 commit 97c4c10

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

backends/gpu/lib/kernels/blas_kernels.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,11 @@ static Error BlasTrsmBatch(
222222
const void** a_array = const_cast<const void**>(b_array + batchCount);
223223

224224
auto side_mode = wrapper::BlasSideMode::FromOpaqueValue(*sideMode);
225-
int32_t a_num_elements = side_mode == CUBLAS_SIDE_LEFT ? m * m : n * n;
225+
int32_t a_num_elements = 0;
226+
if (platform == wrapper::Platform::CUDA)
227+
a_num_elements = side_mode == CUBLAS_SIDE_LEFT ? m * m : n * n;
228+
else
229+
a_num_elements = side_mode == rocblas_side_left ? m * m : n * n;
226230
ptrdiff_t a_batch_stride_bytes = *data_type_size_bytes * a_num_elements;
227231
ptrdiff_t b_batch_stride_bytes = *data_type_size_bytes * m * n;
228232
const char* a_ptr = static_cast<const char*>(A.pointer().raw(platform));

0 commit comments

Comments
 (0)