File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -222,7 +222,10 @@ static Error BlasTrsmBatch(
222
222
const void ** a_array = const_cast <const void **>(b_array + batchCount);
223
223
224
224
auto side_mode = wrapper::BlasSideMode::FromOpaqueValue (*sideMode);
225
- int32_t a_num_elements = side_mode == CUBLAS_SIDE_LEFT ? m * m : n * n;
225
+ int32_t a_num_elements = n * n;
226
+ if ((platform == wrapper::Platform::CUDA && side_mode == CUBLAS_SIDE_LEFT) ||
227
+ (platform == wrapper::Platform::ROCm && side_mode == rocblas_side_left))
228
+ a_num_elements = m * m;
226
229
ptrdiff_t a_batch_stride_bytes = *data_type_size_bytes * a_num_elements;
227
230
ptrdiff_t b_batch_stride_bytes = *data_type_size_bytes * m * n;
228
231
const char * a_ptr = static_cast <const char *>(A.pointer ().raw (platform));
You can’t perform that action at this time.
0 commit comments