
Commit 46d9347

Merge branch 'pytorch:main' into temp-ppc64le-wheel-branch-v8
2 parents: 26e9e6f + 103bf64


83 files changed, +2353 -853 lines changed


.ci/pytorch/test.sh

Lines changed: 2 additions & 1 deletion
@@ -1173,8 +1173,9 @@ build_xla() {
   apply_patches
   SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
   # These functions are defined in .circleci/common.sh in pytorch/xla repo
-  retry install_deps_pytorch_xla $XLA_DIR $USE_CACHE
+  retry install_pre_deps_pytorch_xla $XLA_DIR $USE_CACHE
   CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch:${CMAKE_PREFIX_PATH}" XLA_SANDBOX_BUILD=1 build_torch_xla $XLA_DIR
+  retry install_post_deps_pytorch_xla
   assert_git_not_dirty
 }

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-b2b890e962f5fb6f481e5da2eb4a43bb990d0f1b
+760675ad9aa8e7202d4f9f51fe862e8a9bedb713

aten/src/ATen/DLConvertor.cpp

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
+    case ScalarType::Float4_e2m1fn_x2:
+      TORCH_CHECK(false, "float4 types are not supported by dlpack");
+      break;
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
     case ScalarType::QInt32:

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 17 additions & 8 deletions
@@ -1552,6 +1552,8 @@ void scaled_gemm(
     ScalarType result_dtype,
     bool use_fast_accum,
     bool use_rowwise) {
+  // Note: see `cublasCommonArgs` for various non-intuitive manupulations
+  // of input arguments to this function.
 #if CUDA_VERSION >= 11080 || defined(USE_ROCM)
   const auto computeType = CUBLAS_COMPUTE_32F;
   const auto scaleType = CUDA_R_32F;
@@ -1570,7 +1572,7 @@ void scaled_gemm(
 #else
   // rowwise isn't supported using cublaslt or older hipblaslt
   TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
-#endif
+#endif // if defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
   computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
   computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
   if (result_scale_ptr != nullptr) {
@@ -1583,19 +1585,19 @@ void scaled_gemm(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
-#endif
+#endif // ifndef USE_ROCM
 #ifndef USE_ROCM
   const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
-#endif
+#endif // ifndef USE_ROCM
   CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
   CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
 #ifdef USE_ROCM
   // Cdesc is unused, beta is 0. But hipblaslt needs this set to something reasonable.
   CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
 #else
   CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(bias_dtype), m, n, result_ld);
-#endif
+#endif // ifdef USE_ROCM
   CuBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
   if (bias_ptr) {
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
@@ -1609,7 +1611,14 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
 #else
     TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
-#endif // CUDA_VERSION >= 12080
+#endif // if CUDA_VERSION >= 12080
+  } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
+#if CUDA_VERSION >= 12080
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
+#else
+    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above");
+#endif // if CUDA_VERSION >= 12080
   }

   auto stream = c10::cuda::getCurrentCUDAStream();
@@ -1677,7 +1686,7 @@ void scaled_gemm(
       }
     }
     TORCH_CHECK(found, "could not find valid hipblaslt solution");
-#endif
+#endif // ifndef USE_ROCM
   }
   cublasStatus_t cublasStatus = cublasLtMatmul(
       ltHandle,
@@ -1692,7 +1701,7 @@ void scaled_gemm(
       result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
 #else
       nullptr,
-#endif
+#endif // ifdef USE_ROCM
       Cdesc.descriptor(),
       result_ptr,
       Ddesc.descriptor(),
@@ -1725,7 +1734,7 @@ void scaled_gemm(
       " scaleType ",
       scaleType);
   return;
-#endif // CUDA_VERSION >= 11080 || defined(USE_ROCM)
+#endif // if CUDA_VERSION >= 11080 || defined(USE_ROCM)
   TORCH_CHECK(false, "scaled_gemm is only supported for CUDA 11.8 and above");
 }
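
Editor's note: the new else-if branch above wires nvfp4 scales into cuBLASLt: e8m0 scales select the mxfp8 recipe (one scale per 32 elements, VEC32_UE8M0), while e4m3 scales select the nvfp4 recipe (one scale per 16 unpacked elements, VEC16_UE4M3). Below is a minimal standalone sketch of that dispatch; the enum and struct names are made up for illustration and are not PyTorch or cuBLASLt types.

// Standalone sketch of the scale-mode dispatch above. The enum and struct
// are illustrative stand-ins, not PyTorch or cuBLASLt types.
#include <iostream>
#include <stdexcept>

enum class ScaleDtype { E8M0, E4M3FN };

struct BlockScaleRecipe {
  const char* name;     // human-readable recipe name
  int elems_per_scale;  // unpacked elements covered by one scale value
};

// Mirrors the branch in scaled_gemm: e8m0 scales -> mxfp8 recipe (1 scale per
// 32 elements), e4m3 scales -> nvfp4 recipe (1 scale per 16 unpacked elements).
BlockScaleRecipe select_recipe(ScaleDtype a_scale, ScaleDtype b_scale) {
  if (a_scale != b_scale) {
    throw std::invalid_argument("matching scale dtypes expected");
  }
  if (a_scale == ScaleDtype::E8M0) {
    return {"mxfp8 (VEC32_UE8M0)", 32};
  }
  return {"nvfp4 (VEC16_UE4M3)", 16};
}

int main() {
  auto r = select_recipe(ScaleDtype::E4M3FN, ScaleDtype::E4M3FN);
  std::cout << r.name << ": 1 scale per " << r.elems_per_scale << " elements\n";
}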

aten/src/ATen/cuda/CUDADataType.h

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
       return HIP_R_8F_E4M3_FNUZ;
     case c10::ScalarType::Float8_e5m2fnuz:
       return HIP_R_8F_E5M2_FNUZ;
+#endif
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
+    case c10::ScalarType::Float4_e2m1fn_x2:
+      return CUDA_R_4F_E2M1;
 #endif
     default:
       TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")

aten/src/ATen/cuda/detail/LazyNVRTC.cpp

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,8 @@ NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **)

 CUDA_STUB2(cuModuleLoad, CUmodule*, const char*)
 CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *)
+CUDA_STUB2(cuFuncSetCacheConfig, CUfunction, CUfunc_cache_enum)
+CUDA_STUB3(cuDeviceGetAttribute, int*, CUdevice_attribute_enum, CUdevice)
 CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *)
 CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t)
 CUDA_STUB2(cuGetErrorString, CUresult, const char **)

aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ namespace at::cuda {
   _(cuFuncSetAttribute) \
   _(cuFuncGetAttribute) \
   _(cuPointerGetAttribute) \
+  _(cuFuncSetCacheConfig) \
+  _(cuDeviceGetAttribute) \
+

 #if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
 #define AT_FORALL_NVRTC_EXTENDED(_) \
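
Editor's note: these two entries extend the NVRTC/driver X-macro list so cuFuncSetCacheConfig and cuDeviceGetAttribute are lazily stubbed like the other driver symbols. Below is a small self-contained sketch of the general X-macro pattern these headers rely on; the macro and symbol names are made up and are not the real AT_FORALL_NVRTC or CUDA_STUB* definitions.

// Illustrative sketch of the X-macro pattern: one list of names, expanded
// into whatever declarations or tables are needed. Names are invented.
#include <iostream>

#define FORALL_DEMO_SYMBOLS(_) \
  _(open_device)               \
  _(set_cache_config)          \
  _(get_attribute)

// Expansion 1: define one stub function per listed symbol.
#define DECLARE_STUB(name) void name() { std::cout << "called " #name "\n"; }
FORALL_DEMO_SYMBOLS(DECLARE_STUB)
#undef DECLARE_STUB

// Expansion 2: build a table of the same names, e.g. for lazy lookup.
#define NAME_ENTRY(name) #name,
const char* const kSymbolNames[] = {FORALL_DEMO_SYMBOLS(NAME_ENTRY)};
#undef NAME_ENTRY

int main() {
  set_cache_config();  // one of the generated stubs
  for (const char* n : kSymbolNames) std::cout << n << "\n";
}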

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 30 additions & 8 deletions
@@ -121,6 +121,9 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
  *
  * The transpose flags are derived from the layouts of the passed in tensors
  *
+ * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted
+ * to their unpacked values to match what cuBLAS expects.
+ *
  * @param mat1 First input matrix
  * @param mat2 Second input matrix
  * @param c Output matrix (result)
@@ -173,6 +176,14 @@ struct cublasCommonArgs {
     result_ld = result->stride(transpose_result ? 0 : 1);
     transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n';
     transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n';
+
+    // cuBLAS expects unpacked values of `k`, `lda` and `ldb`, adjust for 4x2 packing
+    // if the gemm operands are in packed float4
+    if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) {
+      k = k * 2;
+      lda = lda * 2;
+      ldb = ldb * 2;
+    }
   }

   // Matrix members
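
Editor's note: Float4_e2m1fn_x2 packs two 4-bit values per stored element, so the packed k, lda and ldb must be doubled before they are handed to cuBLAS. A standalone sketch of that adjustment follows, with illustrative names rather than the actual cublasCommonArgs members.

// Sketch of the packed-fp4 adjustment above: two 4-bit elements per stored
// element, so packed k/lda/ldb are doubled for cuBLAS. Names are illustrative.
#include <cstdint>
#include <iostream>

struct GemmDims {
  int64_t m, n, k;   // logical GEMM dimensions as seen by the backend
  int64_t lda, ldb;  // leading dimensions of the A and B operands
};

GemmDims adjust_for_fp4x2_packing(GemmDims packed, bool operands_are_fp4x2) {
  if (operands_are_fp4x2) {
    packed.k   *= 2;  // two unpacked elements per stored element
    packed.lda *= 2;
    packed.ldb *= 2;
  }
  return packed;
}

int main() {
  // Packed k = 64 means cuBLAS should see an unpacked k of 128.
  GemmDims d = adjust_for_fp4x2_packing({128, 128, 64, 64, 64}, /*operands_are_fp4x2=*/true);
  std::cout << "unpacked k=" << d.k << " lda=" << d.lda << " ldb=" << d.ldb << "\n";
}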
@@ -980,7 +991,7 @@ enum class ScalingType : std::uint8_t {
  * ---------------------------
  * Conditions and corresponding Scaling Types:
  *
- * - If scale tensors are Float8_e8m0fnu:
+ * - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`:
  *   - Returns BlockWise (with additional size checks).
  *
  * - If scale_a.numel() == 1 && scale_b.numel() == 1:
@@ -1001,14 +1012,22 @@ ScalingType get_scaling_type(
     int64_t dim_m,
     int64_t dim_k,
     int64_t dim_n) {
-  // Check for BlockWise scaling (FP8_E8M0 types)
-  if (scale_a.scalar_type() == scale_b.scalar_type() &&
-      scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
-    constexpr int64_t BLOCK_SIZE_K = 32;
+  // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types)
+  if ((scale_a.scalar_type() == scale_b.scalar_type()) &&
+      ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) {
+    const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn;
+
+    // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements
+    // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements.
+    const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32;
+
     constexpr int64_t BLOCK_SIZE_MN = 128;

+    // adjust for fp4x2 packing if necessary
+    const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;
+
     auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
-    auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
+    auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K);
     auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;

     // TODO: We might want to enforce some structure on the shapes of the scale
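
Editor's note: the block-wise path above counts K-blocks against the recipe's block size (16 unpacked elements for nvfp4, 32 for mxfp8) and pads the block count to a multiple of 4. A small standalone sketch of that arithmetic follows, as an illustrative helper rather than the real get_scaling_type.

// Sketch of the K-block arithmetic in the hunk above. Illustrative only.
#include <cstdint>
#include <iostream>

int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int64_t padded_k_blocks(int64_t dim_k, bool is_nvfp4) {
  // fp4 operands are packed two-per-element, so unpack K before counting blocks.
  const int64_t dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;
  const int64_t block_size_k = is_nvfp4 ? 16 : 32;
  const int64_t num_k_blocks = ceil_div(dim_k_unpacked, block_size_k);
  return ceil_div(num_k_blocks, 4) * 4;  // pad block count to a multiple of 4
}

int main() {
  // Packed K = 96: nvfp4 sees 192 unpacked elements -> 12 blocks of 16.
  std::cout << "nvfp4: " << padded_k_blocks(96, true) << " K-blocks\n";
  // mxfp8 with K = 96 -> 3 blocks of 32, padded up to 4.
  std::cout << "mxfp8: " << padded_k_blocks(96, false) << " K-blocks\n";
}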
@@ -1149,13 +1168,16 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
       mat2.sizes()[1], ") must be divisible by 16");
   // Check types
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat1.scalar_type()) || mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat1 to be Float8 or Float4_x2 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
 #ifndef USE_ROCM
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
 #endif
+  if (use_fast_accum) {
+    TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
+  }
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
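
Editor's note: with this hunk, _scaled_mm_out_cuda accepts float8 or packed float4 operands but rejects use_fast_accum whenever either operand is Float4_e2m1fn_x2. A compact sketch of that acceptance logic follows, using a stand-in enum instead of c10::ScalarType; only the logic mirrors the hunk.

// Sketch of the operand checks added above. The enum is a stand-in.
#include <iostream>

enum class Dtype { Float8_e4m3fn, Float8_e5m2, Float4_e2m1fn_x2, BFloat16 };

bool is_float8(Dtype t) {
  return t == Dtype::Float8_e4m3fn || t == Dtype::Float8_e5m2;
}

bool operands_ok(Dtype mat1, Dtype mat2, bool use_fast_accum) {
  const bool dtypes_ok =
      (is_float8(mat1) || mat1 == Dtype::Float4_e2m1fn_x2) &&
      (is_float8(mat2) || mat2 == Dtype::Float4_e2m1fn_x2);
  const bool fast_accum_ok =
      !use_fast_accum ||
      (mat1 != Dtype::Float4_e2m1fn_x2 && mat2 != Dtype::Float4_e2m1fn_x2);
  return dtypes_ok && fast_accum_ok;
}

int main() {
  std::cout << operands_ok(Dtype::Float4_e2m1fn_x2, Dtype::Float4_e2m1fn_x2, false)  // 1
            << operands_ok(Dtype::Float4_e2m1fn_x2, Dtype::Float4_e2m1fn_x2, true)   // 0
            << "\n";
}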

aten/src/ATen/native/cuda/Shape.cu

Lines changed: 5 additions & 2 deletions
@@ -507,7 +507,8 @@ TORCH_IMPL_FUNC(cat_out_cuda)
           kBool,
           kBFloat16,
           AT_EXPAND(AT_FLOAT8_TYPES),
-          AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+          AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+          kFloat4_e2m1fn_x2);
     }
   } else if (materialized.size() > 1 &&
       result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
@@ -542,7 +543,9 @@ TORCH_IMPL_FUNC(cat_out_cuda)
           kFloat8_e4m3fnuz,
           kFloat8_e5m2,
           kFloat8_e5m2fnuz,
-          AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+          AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+          // TODO(#146647): extend this to other shell dtypes
+          kFloat4_e2m1fn_x2);
     }
   } else {
     int64_t offset = 0;

aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu

Lines changed: 1 addition & 7 deletions
@@ -75,12 +75,6 @@ Tensor two_four_sgemm(
     using LayoutC = LayoutOutput;
     constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

-    using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
-        ThreadblockShape,
-        WarpShape,
-        ElementC,
-        AlignmentC,
-        NumEVTEpilogueStages>;
     using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
         ThreadblockShape,
         WarpShape,
@@ -94,7 +88,7 @@ Tensor two_four_sgemm(
         cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
     using BiasTensor =
         cutlass::epilogue::threadblock::VisitorColBroadcast<
-            BiasTileThreadMap,
+            OutputTileThreadMap,
             ElementC,
             cute::Stride<cute::_1, cute::_0, int64_t>>;
     using Bias = std::conditional_t<use_bias, BiasTensor, BiasScalar>;
