Skip to content

Commit 05bc8fe

Browse files
Revert "follow up to pytorch#147548, fix regression on MI300 (pytorch#147878)"
This reverts commit cc444e7. Reverted pytorch#147878 on behalf of https://github.com/wdvr due to temporary reverting to revert an older one in the stack ([comment](pytorch#147878 (comment)))
1 parent 2df9a8d commit 05bc8fe

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,11 +1537,8 @@ void scaled_gemm(
15371537
// rowwise isn't supported using cublaslt or older hipblaslt
15381538
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
15391539
#endif
1540-
// do not remove {}, this scope corresponds to the else condition in the USE_ROCM section above
1541-
{
1542-
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
1543-
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
1544-
}
1540+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
1541+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
15451542
if (result_scale_ptr != nullptr) {
15461543
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
15471544
}

0 commit comments

Comments
 (0)