Commit 2922b9f

naromero77amd authored and pytorchmergebot committed
[ROCm] Fix ADDMM hipBLASLt regression (pytorch#138267)
Fixes pytorch#138067

A partial reversion of this PR: pytorch#137604

The breakage is on AMD GPUs that do not fully support hipBLASLt, e.g. gfx1100

Pull Request resolved: pytorch#138267
Approved by: https://github.com/eqy, https://github.com/jeffdaily
1 parent ad93357 commit 2922b9f
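
For context, the hipBLASLt addmm path on ROCm is toggled through the DISABLE_ADDMM_HIP_LT environment variable, which is the switch the test deleted below relied on. The following is a minimal usage sketch, not part of the commit, assuming a ROCm build of PyTorch with a visible HIP device; the shapes and tolerances mirror the deleted test and are illustrative only.

import os

# "0" opts in to the hipBLASLt addmm path, as the deleted test did. The flag is
# cached in a C++ static, so it must be set before the first GPU matmul runs.
os.environ["DISABLE_ADDMM_HIP_LT"] = "0"

import torch

device = "cuda"  # ROCm devices are exposed through the CUDA device API
bias = torch.randn(128, device=device)
m1 = torch.randn(2048, 2400, device=device)
m2 = torch.randn(128, 2400, device=device)

# linear computes m1 @ m2.T + bias and is eligible for the Lt path for
# float/half/bfloat16 inputs on supported architectures.
out = torch.nn.functional.linear(m1, m2, bias)

# Sanity-check against the CPU reference, as the deleted test did.
ref = torch.nn.functional.linear(m1.cpu(), m2.cpu(), bias.cpu())
print(torch.allclose(ref, out.cpu(), rtol=1e-2, atol=1e-2))

With this commit, the same script run on an architecture that does not fully support hipBLASLt (e.g. gfx1100) should fall back to the regular hipBLAS path rather than aborting.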

File tree

2 files changed: +7, -40 lines

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 7 additions & 2 deletions

@@ -202,7 +202,6 @@ static bool isSupportedHipLtROCmArch(int index) {
       return true;
     }
   }
-  TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
   return false;
 }
 #endif
@@ -265,7 +264,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   IntArrayRef mat2_sizes = mat2.sizes();
   IntArrayRef self__sizes;
   bool useLtInterface = false;
+#if defined(USE_ROCM)
+  // When hipBLASLt is not supported on the architecture,
+  // disable_addmm_cuda_lt will always be to set to true
+  static bool disable_addmm_cuda_lt =
+      !isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
+#else
   static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
+#endif
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
@@ -283,7 +289,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
         self.is_contiguous() && result.is_contiguous() &&
 #ifdef USE_ROCM
-        isSupportedHipLtROCmArch(self.device().index()) &&
         (scalar_type == at::ScalarType::Float ||
          scalar_type == at::ScalarType::Half ||
          scalar_type == at::ScalarType::BFloat16) &&
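
In short, the architecture check is now folded into the one-time initialization of disable_addmm_cuda_lt: on ROCm builds, an unsupported GPU disables the Lt path up front, the TORCH_CHECK that previously aborted on such architectures is removed, and the per-call isSupportedHipLtROCmArch test in the Lt-eligibility condition becomes redundant and is dropped.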

test/test_linalg.py

Lines changed: 0 additions & 38 deletions

@@ -5231,44 +5231,6 @@ def test_corner_cases_of_cublasltmatmul(self, device, dtype):
         m2 = torch.randn(16, 131071, device=device).to(dtype)
         torch.nn.functional.linear(m1, m2, M)
 
-    @onlyCUDA
-    @skipCUDAIfNotRocm
-    @dtypes(*floating_types_and(torch.bfloat16, torch.half))
-    def test_hipblaslt_corner_cases_rocm(self, device, dtype):
-        if dtype == torch.double:
-            raise unittest.SkipTest("hipblasLt doesn't support doubles yet")
-
-        # enable hipblaslt path via env variable.
-        import os
-        DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT"
-        prev_val = os.getenv(DISABLE_ADDMM_HIP_LT)
-        try:
-            os.environ[DISABLE_ADDMM_HIP_LT] = "0"
-            # common case
-            M = torch.randn(128, device=device, dtype=dtype)
-            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
-            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
-            out1 = torch.nn.functional.linear(m1, m2, M)
-            M_cpu = M.to('cpu')
-            m1_cpu = m1.to('cpu')
-            m2_cpu = m2.to('cpu')
-            out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu)
-            self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2))
-
-            # common case without bias
-            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
-            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
-            out2 = torch.nn.functional.linear(m1, m2, bias=None)
-            m1_cpu = m1.to('cpu')
-            m2_cpu = m2.to('cpu')
-            out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None)
-            self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2))
-        finally:
-            if prev_val is None:
-                del os.environ[DISABLE_ADDMM_HIP_LT]
-            else:
-                os.environ[DISABLE_ADDMM_HIP_LT] = prev_val
-
     @dtypesIfCUDA(*floating_and_complex_types_and(
         torch.half,
         *[torch.bfloat16] if SM53OrLater else []
