diff --git a/CMakeLists.txt b/CMakeLists.txt index 8faee83..6a18ec0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,7 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "f58e2df1951b4f99c21be64d4fcd500742a41c59" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h deleted file mode 100644 index f2743bf..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h +++ /dev/null @@ -1,306 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level GEMM definitions combine threadblock-scoped matrix - multiply-add with the appropriate threadblock-scoped epilogue. - - Note, CUTLASS epilogues universally target row-major outputs. Column-major - outputs are accommodated by exchanging A and B operands and assuming - transposed layouts. Partial specializations here choose - 'device::GemmTransposed' to implement this functionality. 
- -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/complex.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#include "gemm_universal_k.h" -#include "cutlass/gemm/kernel/gemm_universal_streamk.h" -#include "cutlass/gemm/kernel/default_gemm.h" -#include "cutlass/gemm/kernel/default_gemm_complex.h" - -#include "cutlass/layout/permute.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Gather operand A by using an index array - bool GatherA = false, - /// Gather operand B by using an index array - bool GatherB = false, - /// Scatter result D by using an index array - bool ScatterD = false, - /// Permute result D - typename PermuteDLayout = layout::NoPermute, - /// Permute operand A - typename PermuteALayout_ = layout::NoPermute, - /// Permute operand B - typename PermuteBLayout_ = layout::NoPermute, - /// - typename Enable = void> -struct DefaultGemmUniversal; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Real-valued GEMM kernels -// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// 
Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Scatter result D by using an index array - bool ScatterD, - /// Permute result D - typename PermuteDLayout, - /// Permute operand A - typename PermuteALayout, - /// Permute operand B - typename PermuteBLayout> -struct DefaultGemmUniversal< - ElementA, LayoutA, - ComplexTransform::kNone, // transform A - kAlignmentA, ElementB, LayoutB, - ComplexTransform::kNone, // transform B - kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, GatherA, GatherB, - ScatterD, PermuteDLayout, PermuteALayout, PermuteBLayout, - typename platform::enable_if< - !cutlass::is_complex::value>::type> { - using DefaultGemmKernel = typename kernel::DefaultGemm< - ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, - LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, - WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, - true, Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, - PermuteDLayout, PermuteALayout, PermuteBLayout>::GemmKernel; - - /// Universal kernel without StreamkFeature member type - template - class SelectBase - : public kernel::GemmUniversal {}; - - /// Universal kernel with StreamkFeature member type - template - class SelectBase - : public kernel::GemmUniversalStreamk< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, SwizzleT> {}; - - /// Select kernel by ThreadblockSwizzle's support for StreamkFeature - using GemmKernel = SelectBase; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Complex-valued GEMM kernels -// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: 
GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -struct DefaultGemmUniversal< - ElementA, LayoutA, TransformA, kAlignmentA, ElementB, LayoutB, TransformB, - kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, false, false, - false, layout::NoPermute, layout::NoPermute, layout::NoPermute, - typename platform::enable_if< - cutlass::is_complex::value>::type> { - using DefaultGemmKernel = typename kernel::DefaultGemmComplex< - ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, - InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, - TransformA, TransformB, Operator, false>::GemmKernel; - - /// Universal kernel without StreamkFeature member type - template - class SelectBase - : public kernel::GemmUniversal {}; - - /// Universal kernel with StreamkFeature member type - template - class SelectBase - : public kernel::GemmUniversalStreamk< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, SwizzleT> {}; - - /// Select kernel by ThreadblockSwizzle's support for StreamkFeature - using GemmKernel = SelectBase; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h deleted file mode 100644 index 411f673..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h +++ /dev/null @@ -1,366 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief -*/ - -#pragma once - -#include "cutlass/arch/mma.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "gemm_universal_k.h" - -#include "default_gemm_universal.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "gemm_universal_base.h" - -#include "cutlass/layout/permute.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/*! - GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a - given GEMM computation (problem geometry and data references), it can be - reused across different GEMM problems having the geometry. (Once initialized, - details regarding problem geometry and references to workspace memory cannot - be updated.) - - The universal GEMM accommodates serial reductions, parallel reductions, - batched strided, and batched array variants. -*/ -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Element type for C and D matrix operands - typename ElementC_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Element type for internal accumulation - typename ElementAccumulator_ = ElementC_, - /// Operator class tag - typename OperatorClass_ = arch::OpClassSimt, - /// Tag indicating architecture to tune for. This is the minimum SM that - /// supports the intended feature. The device kernel can be built - /// targeting any SM larger than this number. 
- typename ArchTag_ = arch::Sm70, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_ = typename DefaultGemmConfiguration< - OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, - ElementAccumulator_>::ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_ = typename DefaultGemmConfiguration< - OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, - ElementAccumulator_>::WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_ = typename DefaultGemmConfiguration< - OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, - ElementAccumulator_>::InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< - OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, - ElementAccumulator_>::EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle_ = - threadblock::GemmIdentityThreadblockSwizzle<>, - /// Number of stages used in the pipelined mainloop - int Stages = - DefaultGemmConfiguration::kStages, - /// Access granularity of A matrix in units of elements - int AlignmentA = - DefaultGemmConfiguration::kAlignmentA, - /// Access granularity of B matrix in units of elements - int AlignmentB = - DefaultGemmConfiguration::kAlignmentB, - /// Operation performed by GEMM - typename Operator_ = typename DefaultGemmConfiguration< - OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, - ElementAccumulator_>::Operator, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB = ComplexTransform::kNone, - /// Gather operand A by using an index array - bool GatherA = false, - /// Gather operand B by using an index array - bool GatherB = false, - /// Scatter result D by using an index array - bool ScatterD = false, - /// Permute result D - typename PermuteDLayout_ = layout::NoPermute, - /// Permute operand A - typename PermuteALayout_ = layout::NoPermute, - /// Permute operand B - typename PermuteBLayout_ = layout::NoPermute> -class GemmUniversal - : public GemmUniversalBase::GemmKernel> { - public: - using ElementAccumulator = ElementAccumulator_; - using OperatorClass = OperatorClass_; - using ArchTag = ArchTag_; - using ThreadblockShape = ThreadblockShape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using EpilogueOutputOp = EpilogueOutputOp_; - using ThreadblockSwizzle = ThreadblockSwizzle_; - using Operator = Operator_; - using PermuteDLayout = PermuteDLayout_; - using PermuteALayout = PermuteALayout_; - using PermuteBLayout = PermuteBLayout_; - static int const kStages = Stages; - static int const kAlignmentA = AlignmentA; - static int const kAlignmentB = AlignmentB; - static int const kAlignmentC = EpilogueOutputOp::kCount; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - - using Base = GemmUniversalBase::GemmKernel>; - - using Arguments = typename Base::Arguments; - using GemmKernel = typename Base::GemmKernel; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for column-major output exchanges problem size and -/// operand. 
-template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for. This is the minimum SM that - /// supports the intended feature. The device kernel can be built - /// targeting any SM larger than this number. - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Epilogue output operator - typename EpilogueOutputOp_, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Access granularity of A matrix in units of elements - int AlignmentA, - /// Access granularity of B matrix in units of elements - int AlignmentB, - /// Operation performed by GEMM - typename Operator_, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Scatter result D by using an index array - bool ScatterD, - /// Permute result D - typename PermuteDLayout_, - /// Permute operand A - typename PermuteALayout_, - /// Permute operand B - typename PermuteBLayout_> -class GemmUniversal< - ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, - layout::ColumnMajor, // partially specialized on LayoutC - ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, - WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, - Stages, AlignmentA, AlignmentB, Operator_, TransformA, TransformB, GatherA, - GatherB, ScatterD, PermuteDLayout_, PermuteALayout_, PermuteBLayout_> { - public: - using ElementA = ElementA_; - using LayoutA = LayoutA_; - using TensorRefA = TensorRef; - using ElementB = ElementB_; - using LayoutB = LayoutB_; - using TensorRefB = TensorRef; - using ElementC = ElementC_; - using LayoutC = layout::ColumnMajor; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - using ElementAccumulator = ElementAccumulator_; - using OperatorClass = OperatorClass_; - using ArchTag = ArchTag_; - using ThreadblockShape = ThreadblockShape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using EpilogueOutputOp = EpilogueOutputOp_; - using ThreadblockSwizzle = ThreadblockSwizzle_; - using Operator = Operator_; - using PermuteDLayout = PermuteDLayout_; - using PermuteALayout = PermuteALayout_; - using PermuteBLayout = PermuteBLayout_; - static int const kStages = Stages; - static int const kAlignmentA = AlignmentA; - static int const kAlignmentB = AlignmentB; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - - using UnderlyingOperator = typename GemmUniversal< - ElementB, typename layout::LayoutTranspose::type, ElementA, - typename layout::LayoutTranspose::type, ElementC, - layout::RowMajor, ElementAccumulator, OperatorClass, 
ArchTag, - ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, Operator, - kTransformB, kTransformA, GatherB, GatherA, ScatterD, PermuteDLayout, - PermuteBLayout, PermuteALayout>::Base; - - using GemmKernel = typename UnderlyingOperator::GemmKernel; - static int const kAlignmentC = EpilogueOutputOp::kCount; - - /// Argument structure - using Arguments = typename UnderlyingOperator::Arguments; - - private: - UnderlyingOperator underlying_operator_; - - public: - /// Constructs the GEMM. - GemmUniversal() {} - - /// Helper to construct a transposed equivalent for the underlying GEMM - /// operator - static Arguments to_underlying_arguments(Arguments const& args) { - return args.transposed_problem(); - } - - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args) { - return UnderlyingOperator::can_implement(to_underlying_arguments(args)); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) { - return UnderlyingOperator::get_workspace_size( - to_underlying_arguments(args)); - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) { - return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) { - return UnderlyingOperator::maximum_active_blocks(smem_capacity); - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), - workspace, stream); - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) { - return underlying_operator_.update(to_underlying_arguments(args), - workspace); - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) { - return underlying_operator_.run(stream); - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) { return run(stream); } - - /// Runs the kernel using initialized state. - Status operator()(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) { - status = run(stream); - } - - return status; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp deleted file mode 100644 index 3b59cc8..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -// In cases where ProblemShape is not a tuple, this is used to check if the -// underlying problem shape type is aliased within or not. -// Used for dispatching GemmUniversal to 2.x API or 3.x API -template -struct IsCutlass3ArrayKernel : cute::false_type {}; - -template -struct IsCutlass3ArrayKernel< - ProblemShape, cute::void_t> - : cute::true_type {}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel - -//////////////////////////////////////////////////////////////////////////////// -#include "xe_gemm_array_cooperative.hpp" diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h deleted file mode 100644 index 0c923e8..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h +++ /dev/null @@ -1,844 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, - batched strided, and batched array variants. -*/ - -#pragma once - -// common -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/detail/layout.hpp" -#include "cutlass/detail/mma.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -#include "cutlass/kernel_launch.h" -#if !defined(__CUDACC_RTC__) - #include "cutlass/cluster_launch.hpp" - #include "cutlass/trace.h" -#endif // !defined(__CUDACC_RTC__) - -// 2.x -#include "gemm_universal_base.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" - -// 3.x -#include "gemm_universal.hpp" - -#if defined(CUTLASS_ENABLE_SYCL) - #include "cutlass/util/sycl_event_manager.hpp" -#endif - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::device { - -//////////////////////////////////////////////////////////////////////////////// - -/*! - GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel - of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. - - It manages the lifetime of the underlying `kernel::Params` struct, and exposes - APIs to create it from the host facing arguments. For power users, new static - methods are exposed in 3.x APIs that bypass the stateful methods or - args->params lowering. - - It supports kernel types that implement both the 2.x and 3.0 APIs, - however, this is done by specializing the implementation of - GemmUniversalAdapter on the two kernel API types, and thus, - GemmUniversalAdapter's behaviour might differ between the two specializations. -*/ -template -class GemmUniversalAdapter; - -//////////////////////////////////////////////////////////////////////////////// -////////////////////////////// CUTLASS 3.x API ///////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -// Work-around for some DispatchPolicy types not having a Stages member. -// In that case, the Stages value is 0. Most code should static_assert -// that the number of stages is valid. - -// Whether DispatchPolicy::Stages is valid. 
-// It should also be convertible to int, but if not, that will show up -// as a build error when GemmUniversalAdapter attempts to assign it to kStages. -template -struct has_Stages : cute::false_type {}; - -template -struct has_Stages> - : cute::true_type {}; - -template -constexpr int stages_member(DispatchPolicy) { - if constexpr (has_Stages::value) { - return DispatchPolicy::Stages; - } else { - return 0; - } -} - -} // namespace detail - -template -class GemmUniversalAdapter>::value>> { - public: - using GemmKernel = GetUnderlyingKernel_t; - using TileShape = typename GemmKernel::TileShape; - using ElementA = typename GemmKernel::ElementA; - using ElementB = typename GemmKernel::ElementB; - using ElementC = typename GemmKernel::ElementC; - using ElementD = typename GemmKernel::ElementD; - using ElementAccumulator = typename GemmKernel::ElementAccumulator; - using DispatchPolicy = typename GemmKernel::DispatchPolicy; - using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; - using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; - - // Map back to 2.x type as best as possible - using LayoutA = - gemm::detail::StrideToLayoutTagA_t; - using LayoutB = - gemm::detail::StrideToLayoutTagB_t; - using LayoutC = - gemm::detail::StrideToLayoutTagC_t; - using LayoutD = - gemm::detail::StrideToLayoutTagC_t; - - static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - - static ComplexTransform const kTransformA = - cute::is_same_v - ? ComplexTransform::kConjugate - : ComplexTransform::kNone; - static ComplexTransform const kTransformB = - cute::is_same_v - ? ComplexTransform::kConjugate - : ComplexTransform::kNone; - - // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 - using MathOperator = cutlass::arch::OpMultiplyAdd; - - using OperatorClass = cutlass::detail::get_operator_class_t< - typename CollectiveMainloop::TiledMma>; - - using ArchTag = typename GemmKernel::ArchTag; - - // NOTE: Assume identity swizzle for now - using ThreadblockSwizzle = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape - using ThreadblockShape = cutlass::gemm::GemmShape(TileShape{}), - cute::size<1>(TileShape{}), - cute::size<2>(TileShape{})>; - - using ClusterShape = cutlass::gemm::GemmShape< - cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; - - // Instruction shape is easy too, since we get that directly from our - // TiledMma's atom shape - using InstructionShape = cutlass::gemm::GemmShape< - cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), - cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), - cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; - - // Legacy: provide a correct warp count, but no reliable warp shape - static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; - - // Warp shape is not a primary API type in 3.x - // But we can best approximate it by inspecting the TiledMma - // For this, we make the assumption that we always have 4 warps along M, and - // rest along N, none along K We also always round up the warp count to 4 if - // the tiled mma is smaller than 128 threads - static constexpr int WarpsInMma = cute::max( - 4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); - static constexpr int WarpsInMmaM 
= 4; - static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); - using WarpCount = cutlass::gemm::GemmShape; - using WarpShape = - cutlass::gemm::GemmShape( - typename CollectiveMainloop::TiledMma{})) / - WarpsInMmaM, - CUTE_STATIC_V(cute::tile_size<1>( - typename CollectiveMainloop::TiledMma{})) / - WarpsInMmaN, - CUTE_STATIC_V(cute::tile_size<2>( - typename CollectiveMainloop::TiledMma{}))>; - - static int constexpr kStages = - detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); - - // Inspect TiledCopy for A and B to compute the alignment size - static int constexpr kAlignmentA = - cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveMainloop::GmemTiledCopyA, ElementA, - typename CollectiveMainloop::TiledMma::ValTypeA>(); - static int constexpr kAlignmentB = - cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveMainloop::GmemTiledCopyB, ElementB, - typename CollectiveMainloop::TiledMma::ValTypeB>(); - static int constexpr kAlignmentC = - cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); - static int constexpr kAlignmentD = - cutlass::detail::get_alignment_count_from_gmem_tiled_copy< - typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); - - using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - // Split-K preserves splits that are 128b aligned - static int constexpr kSplitKAlignment = cute::max( - 128 / sizeof_bits::value, 128 / sizeof_bits::value); - - /// Argument structure: User API - using Arguments = typename GemmKernel::Arguments; - /// Argument structure: Kernel API - using Params = typename GemmKernel::Params; - - private: - /// Kernel API parameters object - Params params_; - - public: - /// Access the Params structure - Params const& params() const { return params_; } - - /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const& args) { - if (GemmKernel::can_implement(args)) { - return Status::kSuccess; - } else { - return Status::kInvalid; - } - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) { - size_t workspace_bytes = 0; - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { - workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * - size_t(cute::size<1>(TileShape{})); - } - - workspace_bytes += GemmKernel::get_workspace_size(args); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - return workspace_bytes; - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args, void* workspace = nullptr) { - auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); - return GemmKernel::get_grid_shape(tmp_params); - } - - /// Computes the grid shape - static dim3 get_grid_shape(Params const& params) { - return GemmKernel::get_grid_shape(params); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int /* smem_capacity */ = -1) { - CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); - int max_active_blocks = -1; - int smem_size = GemmKernel::SharedStorageSize; - - // first, account for dynamic smem capacity if needed - cudaError_t result; - if (smem_size >= (48 << 10)) { - CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - result = cudaFuncSetAttribute(device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " - << cudaGetErrorString(result)); - return -1; - } - } - - // query occupancy after setting smem size - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, device_kernel, - GemmKernel::MaxThreadsPerBlock, smem_size); - - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " - << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " - << workspace - << ", stream: " << (stream ? "non-null" : "null")); - - // Initialize the workspace - Status status = - GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); - if (status != Status::kSuccess) { - return status; - } - // Initialize the Params structure - params_ = GemmKernel::to_underlying_arguments(args, workspace); - // Don't set the function attributes - require the CudaHostAdapter to set - // it. 
- if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - return Status::kSuccess; - } else { - // - // Account for dynamic smem capacity if needed - // - int smem_size = GemmKernel::SharedStorageSize; - - CUTLASS_ASSERT(cuda_adapter == nullptr); - -#if !defined(CUTLASS_ENABLE_SYCL) - if (smem_size >= (48 << 10)) { - CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - cudaError_t result = cudaFuncSetAttribute( - device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " - << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } -#endif - } - return Status::kSuccess; - } - - /// Update API is preserved in 3.0, but does not guarantee a lightweight - /// update of params. - Status update(Arguments const& args, void* workspace = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); - - size_t workspace_bytes = get_workspace_size(args); - if (workspace_bytes > 0 && nullptr == workspace) { - return Status::kErrorWorkspaceNull; - } - - params_ = GemmKernel::to_underlying_arguments(args, workspace); - return Status::kSuccess; - } - - /// Primary run() entry point API that is static allowing users to create and - /// manage their own params. Supplied params struct must be construct by - /// calling GemmKernel::to_underlying_arguments() - static Status run(Params& params, sycl::queue& stream, - CudaHostAdapter* cuda_adapter = nullptr, - bool launch_with_pdl = false) { - CUTLASS_TRACE_HOST("GemmUniversal::run()"); - dim3 const block = GemmKernel::get_block_shape(); - dim3 const grid = get_grid_shape(params); - -#if defined(CUTLASS_ENABLE_SYCL) - const syclcompat::dim3 sycl_block(block.x, block.y, block.z); - const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); -#endif - - // configure smem size and carveout - int smem_size = GemmKernel::SharedStorageSize; - - Status launch_result{Status::kSuccess}; - // Use extended launch API only for mainloops that use it - if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); -#endif -#if !defined(CUTLASS_ENABLE_SYCL) - [[maybe_unused]] constexpr bool is_static_1x1x1 = - cute::is_static_v< - typename GemmKernel::DispatchPolicy::ClusterShape> and - cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; - [[maybe_unused]] dim3 cluster( - cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - - // Dynamic cluster support - [[maybe_unused]] dim3 fallback_cluster = dim3{0, 0, 0}; - if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || - GemmKernel::ArchTag::kMinComputeCapability == 101) { - if constexpr (!cute::is_static_v< - typename GemmKernel::DispatchPolicy::ClusterShape>) { - fallback_cluster = params.hw_info.cluster_shape_fallback; - cluster = params.hw_info.cluster_shape; - } - } - - [[maybe_unused]] void* kernel_params[] = {¶ms}; - - if constexpr (kEnableCudaHostAdapter) { - // - // Use the cuda host adapter - // - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - if (launch_with_pdl) { - CUTLASS_TRACE_HOST( - "GemmUniversal::run() does not support launching with PDL and " - "a custom cuda adapter."); - return 
Status::kErrorInternal; - } - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching kernel with CUDA host adapter"); - #endif - if constexpr (is_static_1x1x1) { - launch_result = cuda_adapter->launch(grid, block, smem_size, stream, - kernel_params, 0); - } else { - launch_result = - cuda_adapter->launch(grid, cluster, fallback_cluster, block, - smem_size, stream, kernel_params, 0); - } - } else { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA " - "host adapter is null"); - return Status::kErrorInternal; - } - } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - [[maybe_unused]] void const* kernel = - (void const*)device_kernel; - static constexpr bool kClusterLaunch = - GemmKernel::ArchTag::kMinComputeCapability == 90; - if constexpr (kClusterLaunch) { - if constexpr (is_static_1x1x1) { - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching static 1x1x1 kernel"); - #endif - launch_result = cutlass::kernel_launch( - grid, block, smem_size, stream, params, launch_with_pdl); - if (launch_result != Status::kSuccess) { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cutlass::kernel_launch reports failure"); - } - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - else { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cutlass::kernel_launch reports success"); - } - #endif - } else { - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching dynamic cluster kernel"); - #endif - launch_result = - ClusterLauncher::launch(grid, cluster, block, smem_size, stream, - kernel, kernel_params, launch_with_pdl); - } - } - - else { - if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || - GemmKernel::ArchTag::kMinComputeCapability == 101 || - GemmKernel::ArchTag::kMinComputeCapability == 120) { - if constexpr (is_static_1x1x1) { - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching static 1x1x1 kernel"); - #endif - launch_result = cutlass::kernel_launch( - grid, block, smem_size, stream, params, launch_with_pdl); - if (launch_result != Status::kSuccess) { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cutlass::kernel_launch reports " - "failure"); - } - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - else { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cutlass::kernel_launch reports " - "success"); - } - #endif - } else { - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching kernel with fall-back " - "cluster"); - #endif - launch_result = ClusterLauncher::launch_with_fallback_cluster( - grid, cluster, fallback_cluster, block, smem_size, stream, - kernel, kernel_params, launch_with_pdl); - } - } - } - } -#endif - } else { - launch_result = Status::kSuccess; - cutlass::arch::synclog_setup(); - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - void* kernel_params[] = {¶ms}; -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching kernel with CUDA host adapter"); -#endif - launch_result = cuda_adapter->launch(grid, block, smem_size, stream, - kernel_params, 0); - - } else { - CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); - return Status::kErrorInternal; - } - } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); -#if defined(CUTLASS_ENABLE_SYCL) - // sycl::queue q = stream; // ? 
*stream : - // syclcompat::get_default_queue(); - #if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) - using namespace syclcompat::experimental; - if constexpr (cute::is_same_v) { - auto event = launch>( - launch_policy{sycl_grid, sycl_block, - local_mem_size { - static_cast(smem_size) - }}, - q, params); - EventManager::getInstance().addEvent(event); - } else { - auto event = launch>( - launch_policy{ - sycl_grid, sycl_block, - local_mem_size{static_cast(smem_size)} - #if defined(SYCL_INTEL_TARGET) - , - kernel_properties { - sycl_exp::sub_group_size - } - #endif - }, - stream, params); - EventManager::getInstance().addEvent(event); - } - #else - #if defined(SYCL_INTEL_TARGET) - constexpr bool allow_subgroup_size_prop = true; - #else - constexpr bool allow_subgroup_size_prop = false; - #endif - auto kernel_props = [] { - constexpr bool is_device_agnostic = - cute::is_same_v; - if constexpr (!allow_subgroup_size_prop or is_device_agnostic) { - using EmptyProperties = - decltype(sycl::ext::oneapi::experimental::properties()); - return syclcompat::experimental::kernel_properties< - EmptyProperties>{}; - } else { - return syclcompat::experimental::kernel_properties{ - sycl::ext::oneapi::experimental::sub_group_size< - DispatchPolicy::SubgroupSize>}; - } - }(); - syclcompat::experimental::launch_properties launch_props{ - sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), - }; - syclcompat::experimental::launch_policy policy{ - sycl_grid, sycl_block, launch_props, kernel_props}; - auto event = - syclcompat::experimental::launch>( - policy, stream, params); - EventManager::getInstance().addEvent(event); - #endif // !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) -#else - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); - #endif - launch_result = cutlass::kernel_launch( - grid, block, smem_size, stream, params, launch_with_pdl); - if (launch_result != Status::kSuccess) { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cutlass::kernel_launch reports failure"); - } - #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - else { - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cutlass::kernel_launch reports success"); - } - #endif -#endif - } - } - - cudaError_t result = cudaGetLastError(); - if (cudaSuccess == result && Status::kSuccess == launch_result) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST( - "GemmUniversal::run: cudaGetLastError reports success"); -#endif - return Status::kSuccess; - } else { - CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); - return Status::kErrorInternal; - } - } - - // - // Non-static launch overloads that first create and set the internal params - // struct of this kernel handle. - // - - /// Launches the kernel after first constructing Params internal state from - /// supplied arguments. - Status run(Arguments const& args, void* workspace, sycl::queue& stream, - CudaHostAdapter* cuda_adapter = nullptr, - bool launch_with_pdl = false) { - Status status = initialize(args, workspace, stream, cuda_adapter); - - if (Status::kSuccess == status) { - status = run(params_, stream, cuda_adapter, launch_with_pdl); - } - return status; - } - - /// Launches the kernel after first constructing Params internal state from - /// supplied arguments. 
- Status operator()(Arguments const& args, void* workspace, sycl::queue& stream, - CudaHostAdapter* cuda_adapter = nullptr, - bool launch_with_pdl = false) { - return run(args, workspace, stream, cuda_adapter, launch_with_pdl); - } - - /// Overload that allows a user to re-launch the same kernel without updating - /// internal params struct. - Status run(sycl::queue& stream, CudaHostAdapter* cuda_adapter = nullptr, - bool launch_with_pdl = false) { - return run(params_, stream, cuda_adapter, launch_with_pdl); - } - - /// Overload that allows a user to re-launch the same kernel without updating - /// internal params struct. - Status operator()(sycl::queue& stream, - CudaHostAdapter* cuda_adapter = nullptr, - bool launch_with_pdl = false) { - return run(params_, stream, cuda_adapter, launch_with_pdl); - } -}; - -//////////////////////////////////////////////////////////////////////////////// -////////////////////////////// CUTLASS 2.x API ///////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalAdapter< - GemmKernel_, cute::enable_if_t>::value>> { - public: - using GemmKernel = GetUnderlyingKernel_t; - - static bool const kInternalTranspose = - !cutlass::epilogue::threadblock::detail::is_2x_evt_v< - typename GemmKernel::Epilogue> && // 2.x EVT does not require - // internal transpose - cute::is_same::value; - - using ThreadblockShape = typename GemmKernel::Mma::Shape; - using WarpShape = typename GemmKernel::WarpShape; - using InstructionShape = typename GemmKernel::InstructionShape; - - // warp-level, arch-level (instruction), math operator - using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - - // Operator class and arch tag extract bottom-up - // set it for top-level gemm device-level template - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - - // Type, layout, and complex transform deliberately exchanged with B - using MapArguments = kernel::detail::MapArguments< - typename GemmKernel::ElementA, typename GemmKernel::LayoutA, - GemmKernel::kTransformA, GemmKernel::kAlignmentA, - typename GemmKernel::ElementB, typename GemmKernel::LayoutB, - GemmKernel::kTransformB, GemmKernel::kAlignmentB, - typename GemmKernel::LayoutC, kInternalTranspose>; - - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static int const kAlignmentA = MapArguments::kAlignmentA; - - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - static int const kAlignmentB = MapArguments::kAlignmentB; - - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename MapArguments::LayoutC; - static int const kAlignmentC = GemmKernel::kAlignmentC; - - // C and D same type for 2.x kernel - using ElementD = ElementC; - using LayoutD = LayoutC; - - using TensorRefA = TensorRef; - using TensorRefB = TensorRef; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - - static int const kStages = GemmKernel::Mma::kStages; - - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; - 
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using UnderlyingOperator = GemmUniversalBase; - using Arguments = typename UnderlyingOperator::Arguments; - - private: - UnderlyingOperator underlying_operator_; - - public: - /// Constructs the GEMM. - GemmUniversalAdapter() {} - - /// Helper to construct a transposed equivalent for the underlying GEMM - /// operator - static Arguments to_underlying_arguments(Arguments const& args) { - if (kInternalTranspose) { - return args.transposed_problem(); - } else { - return args; - } - } - - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args, - CudaHostAdapter* cuda_adapter = nullptr) { - return UnderlyingOperator::can_implement(to_underlying_arguments(args), - cuda_adapter); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args, - CudaHostAdapter* cuda_adapter = nullptr) { - return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), - cuda_adapter); - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) { - return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) { - return UnderlyingOperator::maximum_active_blocks(smem_capacity); - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), - workspace, stream, cuda_adapter); - } - - /// Lightweight update given a subset of arguments. - Status update(Arguments const& args) { - return underlying_operator_.update(to_underlying_arguments(args)); - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return underlying_operator_.run(stream, cuda_adapter); - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return run(stream); - } - - /// Runs the kernel using initialized state. - Status operator()(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = initialize(args, workspace, stream, cuda_adapter); - - if (status == Status::kSuccess) { - status = run(stream, cuda_adapter); - } - - return status; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::device - -//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h deleted file mode 100644 index b909318..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h +++ /dev/null @@ -1,524 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief The universal GEMM accommodates streamk, batched strided, and batched - array variants. -*/ - -#pragma once - -#if defined(__CUDACC_RTC__) - #include -#else - #include -#endif - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/device_kernel.h" -#include "cutlass/cuda_host_adapter.hpp" - -#include "cutlass/gemm/gemm.h" -#include "gemm_universal_k.h" - -#include "default_gemm_universal.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" - -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalBase { - public: - using GemmKernel = GemmKernel_; - - /// Boolean indicating whether the CudaHostAdapter is enabled - static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - - using ThreadblockShape = typename GemmKernel::Mma::Shape; - - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; - - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; - - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - - /// Numerical accumulation element type - using ElementAccumulator = typename GemmKernel::Mma::ElementC; - - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; - - /// Argument structure - using Arguments = typename GemmKernel::Arguments; - - /// Index of the GEMM 
Kernel within the CudaHostAdapter - static int32_t const kGemmKernelIndex = 0; - - /// Kernel dynamic shared memory allocation requirement - /// Update the kernel function's shared memory configuration for the current - /// device - static constexpr size_t kSharedStorageSize = - sizeof(typename GemmKernel::SharedStorage); - - protected: - // - // Device properties (uniform across all instances of the current thread) - // - - // Device ordinal - CUTLASS_THREAD_LOCAL static int device_ordinal_; - - /// Device SM count - CUTLASS_THREAD_LOCAL static int device_sms_; - - /// Kernel SM occupancy (in thread blocks) - CUTLASS_THREAD_LOCAL static int sm_occupancy_; - - protected: - /// Initialize static thread-local members for the thread's current device, - /// if necessary. - static Status init_device_props() { - CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); - - cudaError_t cudart_result; - - // Get current device ordinal - int current_ordinal; - cudart_result = cudaGetDevice(¤t_ordinal); - if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " - << cudaGetErrorString(cudart_result)); - return Status::kErrorInternal; - } - - // Done if matches the current static member - if (current_ordinal == device_ordinal_) { - // Already initialized - return Status::kSuccess; - } - - // Update SM count member - cudart_result = cudaDeviceGetAttribute( - &device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); - if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " - << cudaGetErrorString(cudart_result)); - return Status::kErrorInternal; - } - - // If requires more than 48KB: configure for extended, dynamic shared memory - if constexpr (kSharedStorageSize >= (48 << 10)) { - cudart_result = cudaFuncSetAttribute( - Kernel2, cudaFuncAttributeMaxDynamicSharedMemorySize, - kSharedStorageSize); - if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " - << cudaGetErrorString(cudart_result)); - return Status::kErrorInternal; - } - } - - // Update SM occupancy member - cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( - &sm_occupancy_, Kernel2, GemmKernel::kThreadCount, - kSharedStorageSize, cudaOccupancyDisableCachingOverride); - if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned " - "error " - << cudaGetErrorString(cudart_result)); - return Status::kErrorInternal; - } - - // Update device ordinal member on success - device_ordinal_ = current_ordinal; - - CUTLASS_TRACE_HOST( - " " - "device_ordinal: (" - << device_ordinal_ - << "), " - "device_sms: (" - << device_sms_ - << "), " - "sm_occupancy: (" - << sm_occupancy_ - << ") " - "smem_size: (" - << kSharedStorageSize - << ") " - "GemmKernel::kThreadCount: (" - << GemmKernel::kThreadCount << ")"); - - return Status::kSuccess; - } - - protected: - // - // Instance data members - // - - /// Kernel parameters - typename GemmKernel::Params params_; - - /// Initialize params member - Status init_params(Arguments const& args, - CudaHostAdapter* cuda_adapter = nullptr) { - int32_t device_sms = 0; - int32_t sm_occupancy = 0; - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - - // - // Occupancy query using CudaHostAdapter::query_occupancy(). 
- // - - if (cuda_adapter) { - Status status = cuda_adapter->query_occupancy( - &device_sms, &sm_occupancy, kGemmKernelIndex, - GemmKernel::kThreadCount, kSharedStorageSize); - - CUTLASS_ASSERT(status == Status::kSuccess); - - if (status != Status::kSuccess) { - return status; - } - } else { - return Status::kErrorInternal; - } - } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - - // Initialize static device properties, if necessary - Status result = init_device_props(); - - if (result != Status::kSuccess) { - return result; - } - - // - // Use thread-local static members for occupancy query initialized by call - // to `init_device_props()` - // - - device_sms = device_sms_; - sm_occupancy = sm_occupancy_; - } - - // Initialize params member - params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy); - return Status::kSuccess; - } - - public: - //--------------------------------------------------------------------------------------------- - // Stateless API - //--------------------------------------------------------------------------------------------- - - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args, - CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); - - if (!kEnableCudaHostAdapter || cuda_adapter) { - dim3 grid = get_grid_shape(args, cuda_adapter); - - if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) { - return Status::kErrorInvalidProblem; - } - } else { - // - // With a null host adapter, a conservative grid shape is computed and - // required to conform to CUDA grid dimension limits. - // - - int64_t logicalGridM = - (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / - ThreadblockShape::kM; - int64_t logicalGridN = - (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / - ThreadblockShape::kN; - int32_t logicalGridL = args.batch_count; - - if ((int64_t(std::numeric_limits::max()) < logicalGridM) || - (int64_t(std::numeric_limits::max()) < logicalGridN) || - (int32_t(std::numeric_limits::max()) < logicalGridL)) { - return Status::kErrorInvalidProblem; - } - } - - return GemmKernel::can_implement(args); - } - - /// Returns the workspace size (in bytes) needed for the problem - /// geometry expressed by these arguments - static size_t get_workspace_size(Arguments const& args, - CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); - - // Initialize parameters from args - GemmUniversalBase base; - if (base.init_params(args, cuda_adapter) != Status::kSuccess) { - return 0; - } - - // Get size from parameters - size_t workspace_bytes = base.params_.get_workspace_size(); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - return workspace_bytes; - } - - /// Returns the grid extents in thread blocks to launch - static dim3 get_grid_shape(Arguments const& args, - CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); - - // Initialize parameters from args - GemmUniversalBase base; - if (base.init_params(args, cuda_adapter) != Status::kSuccess) { - return dim3(0, 0, 0); - } - - // Get dims from parameters - dim3 grid_dims = base.params_.get_grid_dims(); - - CUTLASS_TRACE_HOST(" tiled_shape: " - << base.params_.get_tiled_shape() << "\n" - << " grid_dims: {" << grid_dims << "}"); - - return grid_dims; - } - - /// Returns the maximum number of active thread blocks per multiprocessor - 
static int maximum_active_blocks(CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); - - int32_t device_sms = 0; - int32_t sm_occupancy = 0; - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - - if (cuda_adapter) { - Status status = cuda_adapter->query_occupancy( - &device_sms, &sm_occupancy, kGemmKernelIndex, - GemmKernel::kThreadCount, kSharedStorageSize); - - CUTLASS_ASSERT(status == Status::kSuccess); - - if (status != Status::kSuccess) { - return -1; - } - } else { - return -1; - } - } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - // Initialize static device properties, if necessary - if (init_device_props() != Status::kSuccess) { - return -1; - } - - sm_occupancy = sm_occupancy_; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); - return sm_occupancy; - } - - //--------------------------------------------------------------------------------------------- - // Stateful API - //--------------------------------------------------------------------------------------------- - - /// Initializes GEMM state from arguments and workspace memory - Status initialize(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " - << workspace - << ", stream: " << (stream ? "non-null" : "null")); - - // Initialize parameters from args - Status result = init_params(args, cuda_adapter); - if (result != Status::kSuccess) { - return result; - } - - // Assign and prepare workspace memory - if (args.mode == GemmUniversalMode::kGemm) { - return params_.init_workspace(workspace, stream); - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments. - Status update(Arguments const& args) { - CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); - params_.update(args); - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); - - // Configure grid and block dimensions - dim3 block(GemmKernel::kThreadCount, 1, 1); - dim3 grid = params_.get_grid_dims(); - - // Launch kernel - CUTLASS_TRACE_HOST( - " " - "grid: (" - << grid - << "), " - "block: (" - << block - << "), " - "SMEM: (" - << kSharedStorageSize << ")"); - - cutlass::arch::synclog_setup(); - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - void* kernel_params[] = {¶ms_}; - return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, - kernel_params, 0); - } else { - return Status::kErrorInternal; - } - } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - -#if defined(CUTLASS_ENABLE_SYCL) - const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); - const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); - - sycl::queue q = stream ? 
*stream : syclcompat::get_default_queue(); - syclcompat::experimental::launch>( - syclcompat::experimental::launch_policy{ - sycl_grid, sycl_block, - #if defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) - sycl::ext::oneapi::experimental::work_group_scratch_size( - kSharedStorageSize) - #else - syclcompat::experimental::local_mem_size{ - static_cast(kSharedStorageSize)} - #endif - }, - q, params_); -#else - Kernel2<<>>(params_); -#endif - - // Query for errors - cudaError_t result = cudaGetLastError(); - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" grid launch failed with error " - << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return run(stream, cuda_adapter); - } - - /// Runs the kernel using initialized state. - Status operator()(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = initialize(args, workspace, stream, cuda_adapter); - - if (status == Status::kSuccess) { - status = run(stream, cuda_adapter); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Static initializers -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Device ordinal -template -CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_ordinal_ = -1; - -/// Device SM count -template -CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_sms_ = -1; - -/// Kernel SM occupancy (in thread blocks) -template -CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h deleted file mode 100644 index 19871ee..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h +++ /dev/null @@ -1,649 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/fast_math.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/complex.h" -#include "cutlass/semaphore.h" -#include "gemm_universal.hpp" - -#include "cutlass/layout/matrix.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/params_universal_base.h" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversal< - Mma_, Epilogue_, ThreadblockSwizzle_, void, - // 3.x kernels use the first template argument to define the ProblemShape - // We use this invariant to SFINAE dispatch against either the 2.x API or - // the 3.x API - cute::enable_if_t::value || - IsCutlass3ArrayKernel::value)>> { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = - Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment = const_max( - 128 / sizeof_bits::value, 128 / sizeof_bits::value); - - // - // Structures - // - - /// Argument structure - struct Arguments : UniversalArgumentsBase { - // - // 
Data members - // - - typename EpilogueOutputOp::Params epilogue; - - void const* ptr_A; - void const* ptr_B; - void const* ptr_C; - void* ptr_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C; - - typename LayoutA::Stride stride_a; - typename LayoutB::Stride stride_b; - typename LayoutC::Stride stride_c; - typename LayoutC::Stride stride_d; - - typename LayoutA::Stride::LongIndex lda; - typename LayoutB::Stride::LongIndex ldb; - typename LayoutC::Stride::LongIndex ldc; - typename LayoutC::Stride::LongIndex ldd; - - int const* ptr_gather_A_indices; - int const* ptr_gather_B_indices; - int const* ptr_scatter_D_indices; - - // - // Methods - // - - Arguments() - : ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_gather_A_indices(nullptr), - ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr) {} - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode, GemmCoord problem_size, int batch_count, - typename EpilogueOutputOp::Params epilogue, void const* ptr_A, - void const* ptr_B, void const* ptr_C, void* ptr_D, - int64_t batch_stride_A, int64_t batch_stride_B, - int64_t batch_stride_C, int64_t batch_stride_D, - typename LayoutA::Stride stride_a, - typename LayoutB::Stride stride_b, - typename LayoutC::Stride stride_c, - typename LayoutC::Stride stride_d, - int const* ptr_gather_A_indices = nullptr, - int const* ptr_gather_B_indices = nullptr, - int const* ptr_scatter_D_indices = nullptr) - : UniversalArgumentsBase(mode, problem_size, batch_count, - batch_stride_D), - epilogue(epilogue), - ptr_A(ptr_A), - ptr_B(ptr_B), - ptr_C(ptr_C), - ptr_D(ptr_D), - batch_stride_A(batch_stride_A), - batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), - stride_a(stride_a), - stride_b(stride_b), - stride_c(stride_c), - stride_d(stride_d), - ptr_gather_A_indices(ptr_gather_A_indices), - ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) { - lda = 0; - ldb = 0; - ldc = 0; - ldd = 0; - CUTLASS_TRACE_HOST( - "GemmUniversal::Arguments::Arguments() - problem_size: " - << problem_size); - } - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode, GemmCoord problem_size, int batch_count, - typename EpilogueOutputOp::Params epilogue, void const* ptr_A, - void const* ptr_B, void const* ptr_C, void* ptr_D, - int64_t batch_stride_A, int64_t batch_stride_B, - int64_t batch_stride_C, int64_t batch_stride_D, - typename LayoutA::Stride::LongIndex lda, - typename LayoutB::Stride::LongIndex ldb, - typename LayoutC::Stride::LongIndex ldc, - typename LayoutC::Stride::LongIndex ldd, - int const* ptr_gather_A_indices = nullptr, - int const* ptr_gather_B_indices = nullptr, - int const* ptr_scatter_D_indices = nullptr) - : UniversalArgumentsBase(mode, problem_size, batch_count, - batch_stride_D), - epilogue(epilogue), - ptr_A(ptr_A), - ptr_B(ptr_B), - ptr_C(ptr_C), - ptr_D(ptr_D), - batch_stride_A(batch_stride_A), - batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), - lda(lda), - ldb(ldb), - ldc(ldc), - ldd(ldd), - ptr_gather_A_indices(ptr_gather_A_indices), - ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) { - stride_a = make_Coord(lda); - stride_b = make_Coord(ldb); - stride_c = make_Coord(ldc); - stride_d = make_Coord(ldd); - CUTLASS_TRACE_HOST( - "GemmUniversal::Arguments::Arguments() - problem_size: " - << problem_size); - } - - /// Returns arguments for the transposed problem - Arguments transposed_problem() 
const { - Arguments args(*this); - - std::swap(args.problem_size.m(), args.problem_size.n()); - std::swap(args.ptr_A, args.ptr_B); - std::swap(args.lda, args.ldb); - std::swap(args.stride_a, args.stride_b); - std::swap(args.batch_stride_A, args.batch_stride_B); - std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); - - return args; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - : UniversalParamsBase { - using ParamsBase = - UniversalParamsBase; - - // - // Data members - // - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::Params params_D; - - typename EpilogueOutputOp::Params output_op; - - void* ptr_A; - void* ptr_B; - void* ptr_C; - void* ptr_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C; - - int* ptr_gather_A_indices; - int* ptr_gather_B_indices; - int* ptr_scatter_D_indices; - - // - // Host dispatch API - // - - /// Default constructor - Params() = default; - - /// Constructor - Params(Arguments const& args, /// GEMM application arguments - int device_sms, /// Number of SMs on the device - int sm_occupancy) /// Kernel SM occupancy (in thread blocks) - : ParamsBase(args, device_sms, sm_occupancy), - params_A(args.lda - ? make_Coord_with_padding(args.lda) - : args.stride_a), - params_B(args.ldb - ? make_Coord_with_padding(args.ldb) - : args.stride_b), - params_C(args.ldc - ? make_Coord_with_padding(args.ldc) - : args.stride_c), - params_D(args.ldd - ? make_Coord_with_padding(args.ldd) - : args.stride_d), - output_op(args.epilogue), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_D(args.ptr_D), - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), - ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), - ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) {} - - /// Lightweight update given a subset of arguments. - void update(Arguments const& args) { - CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); - - // Update input/output pointers - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); - ptr_D = args.ptr_D; - - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C = args.batch_stride_C; - this->batch_stride_D = args.batch_stride_D; - - ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); - ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); - ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); - - output_op = args.epilogue; - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - public: - // - // Host dispatch API - // - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { - CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); - - static int const kAlignmentA = - (cute::is_same>::value) ? 32 - : (cute::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = - (cute::is_same>::value) ? 32 - : (cute::is_same>::value) - ? 
64 - : Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = - (cute::is_same>::value) ? 32 - : (cute::is_same>::value) - ? 64 - : Epilogue::OutputTileIterator::kElementsPerAccess; - - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; - - if (cute::is_same::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } else if (cute::is_same::value) { - isAMisaligned = problem_size.m() % kAlignmentA; - } else if (cute::is_same>::value || - cute::is_same>::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } - - if (cute::is_same::value) { - isBMisaligned = problem_size.n() % kAlignmentB; - } else if (cute::is_same::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } else if (cute::is_same>::value || - cute::is_same>::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } - - if (cute::is_same::value) { - isCMisaligned = problem_size.n() % kAlignmentC; - } else if (cute::is_same::value) { - isCMisaligned = problem_size.m() % kAlignmentC; - } else if (cute::is_same>::value || - cute::is_same>::value) { - isCMisaligned = problem_size.n() % kAlignmentC; - } - - if (isAMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } - - if (isBMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } - - if (isCMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } - - CUTLASS_TRACE_HOST(" returning kSuccess"); - - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) { - return can_implement(args.problem_size); - } - - public: - // - // Device-only API - // - - // Factory invocation - CUTLASS_DEVICE - static void invoke(Params const& params, SharedStorage& shared_storage) { - GemmUniversal op; - op(params, shared_storage); - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) { - ThreadblockSwizzle threadblock_swizzle; - run_with_swizzle(params, shared_storage, threadblock_swizzle); - } - - /// Executes one GEMM with an externally-provided swizzling function - CUTLASS_DEVICE - void run_with_swizzle(Params const& params, SharedStorage& shared_storage, - ThreadblockSwizzle& threadblock_swizzle) { - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); - - // - // Fetch pointers based on mode. 
- // - if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } else if (params.mode == GemmUniversalMode::kBatched) { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } else if (params.mode == GemmUniversalMode::kArray) { - ptr_A = static_cast( - params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast( - params.ptr_B)[threadblock_tile_offset.k()]; - } - - syncthreads(); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_B{ - offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int thread_idx = ThreadIdxX(); - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, - thread_idx, tb_offset_A, params.ptr_gather_A_indices); - - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, - thread_idx, tb_offset_B, params.ptr_gather_B_indices); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - - int lane_idx = ThreadIdxX() % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = - (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + - threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - ElementC* ptr_C = static_cast(params.ptr_C); - ElementC* ptr_D = static_cast(params.ptr_D); - - // - // Fetch pointers based on mode. - // - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - if (params.mode == GemmUniversalMode::kGemm) { - // If performing a reduction via split-K, fetch the initial - // synchronization - if (params.grid_tiled_shape.k() > 1) { - // Fetch the synchronization lock initially but do not block. 
- semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is - // currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), - params.grid_tiled_shape.k()); - } - } else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } else if (params.mode == GemmUniversalMode::kBatched) { - ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } else if (params.mode == GemmUniversalMode::kArray) { - ptr_C = static_cast( - params.ptr_C)[threadblock_tile_offset.k()]; - ptr_D = static_cast( - params.ptr_D)[threadblock_tile_offset.k()]; - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, ptr_C, params.problem_size.mn(), thread_idx, - threadblock_offset, params.ptr_scatter_D_indices); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, ptr_D, params.problem_size.mn(), thread_idx, - threadblock_offset, params.ptr_scatter_D_indices); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator - // construction - if (params.mode == GemmUniversalMode::kGemm && - params.grid_tiled_shape.k() > 1) { - // For subsequent threadblocks, the source matrix is held in the 'D' - // tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (params.mode == GemmUniversalMode::kGemm && - params.grid_tiled_shape.k() > 1) { - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp deleted file mode 100644 index bd49242..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp +++ /dev/null @@ -1,562 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/dispatch_policy.hpp" -// #include "cutlass/epilogue/collective/collective_epilogue.hpp" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" -#include "cutlass/detail/layout.hpp" - -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class CollectiveEpilogue { - static_assert(cutlass::detail::dependent_false, - "Could not find an epilogue specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class CollectiveEpilogue { - public: - // - // Type Aliases - // - using DispatchPolicy = IntelXeXMX16Group; - using CtaTileMNK = CtaTileMNK_; - using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; - using ElementAccumulator = ElementC_; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - using CopyOpG2R = CopyOpG2R_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; - using CopyOpR2G = CopyOpR2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; - - using ThreadEpilogueOp = - typename fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = CopyOpG2R; - using GmemTiledCopyD = cute::conditional_t && - not cute::is_void_v, - CopyOpR2G, XE_2D_U32x8x16_ST_N>; - using ElementOutput = ElementD; - using ElementCompute = ElementAccumulator; - using 
ElementSource = typename FusionCallbacks::ElementSource; - using ElementScalar = typename FusionCallbacks::ElementScalar; - static constexpr FloatRoundStyle RoundStyle = - FloatRoundStyle::round_to_nearest; - - static_assert( - cute::is_same_v< - typename FusionCallbacks::Operation, - fusion::LinearCombination>, - "Only Linear Combination Epilogue is supported for Grouped GEMM at the " - "moment."); - - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - - static_assert(cute::rank(CtaTileMNK{}) == 3, - "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(InternalStrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]"); - - static_assert(std::is_same_v, - "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, - "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, - "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, - "Copy operation to shared memory is not supported"); - - using CopyThreadShape = Shape<_1, Int>; - using Trait_C = Copy_Traits; - using XE_Copy_C = decltype(make_tiled_copy( - Copy_Atom{}, Layout{}, - make_layout( - shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); - using Trait_D = Copy_Traits; - using XE_Copy_D = decltype(make_tiled_copy( - Copy_Atom{}, Layout{}, - make_layout( - shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); - - private: - // constexpr static bool is_source_supported = not cute::is_void_v; - constexpr static bool is_source_supported = false; - constexpr static bool is_destination_supported = - not cute::is_void_v && not cute::is_void_v; - - public: - using EmptyType = cute::tuple<>; - using SmemCStorage = EmptyType; - using SmemDStorage = EmptyType; - - struct TensorStorageImpl : cute::tuple { - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - }; - - struct SharedStorage { - using TensorStorage = TensorStorageImpl; - - TensorStorage tensors; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - - using TensorC = - decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), - make_shape(0, 0, 0), InternalStrideC{})); //(m, n) - using TensorD = - decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), - make_shape(0, 0, 0), InternalStrideD{})); //(m, n) - using EpilogueTensors = cute::tuple; - - // Host side epilogue arguments - struct Arguments { - typename FusionCallbacks::Arguments thread{}; - ElementC const** ptr_C; - StrideC dC; - ElementD** ptr_D; - StrideD dD; - }; - - // Device side epilogue params - struct Params { - typename FusionCallbacks::Params thread{}; - XE_Copy_C xe_load_c; - XE_Copy_D xe_store_d; - ElementC const** ptr_C; - StrideC dC; - ElementD** ptr_D; - StrideD dD; - }; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const& problem_shape, Arguments const& args, - [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only - // rank-3 (MNK) - auto problem_shape_MNL = repeat_like( - typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); - auto [M, N, L] = problem_shape_MNL; - - XE_Copy_C xe_load_c = {}; - if constexpr (is_source_supported) { - ElementC const* ptr_C_first_batch = - reinterpret_cast(args.ptr_C); - TensorC mC_mnl = - make_tensor(make_gmem_ptr(ptr_C_first_batch), - make_layout(make_shape(M, N, L), 
InternalStrideC{})); - xe_load_c = {xe_load_c.with(mC_mnl)}; - } - - XE_Copy_D xe_store_d = {}; - if constexpr (is_destination_supported) { - ElementD* ptr_D_first_batch = reinterpret_cast(args.ptr_D); - TensorD mD_mnl = - make_tensor(make_gmem_ptr(ptr_D_first_batch), - make_layout(make_shape(M, N, L), InternalStrideD{})); - xe_store_d = {xe_store_d.with(mD_mnl)}; - } - - return {FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, - workspace), - xe_load_c, - xe_store_d, - args.ptr_C, - args.dC, - args.ptr_D, - args.dD}; - } - - template - static size_t get_workspace_size(ProblemShape const& problem_shape, - Arguments const& args) { - return 0; - } - - template - static cutlass::Status initialize_workspace( - ProblemShape const& problem_shape, Arguments const& args, void* workspace, - cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - template - static bool can_implement(ProblemShape problem_shape, Arguments const& args) { - constexpr int copy_alignment_bits = 128; - constexpr int batch_alignment_bits = 512; - - bool implementable = true; - bool fusion_implementable = true; - - for (int i = 0; i < problem_shape.groups(); ++i) { - auto problem_shape_MNKL = - append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - - if constexpr (is_destination_supported) { - constexpr int min_aligned_elements_D = - copy_alignment_bits / sizeof_bits::value; - implementable &= - cutlass::detail::check_alignment( - cute::make_shape(M, N, L), InternalStrideD{}); - if (L > 1) { - constexpr int min_batch_aligned_elements_D = - batch_alignment_bits / sizeof_bits::value; - implementable &= - get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; - } - } - - if constexpr (is_source_supported) { - constexpr int min_aligned_elements_C = - copy_alignment_bits / sizeof_bits::value; - implementable &= - cutlass::detail::check_alignment( - cute::make_shape(M, N, L), InternalStrideC{}); - if (L > 1) { - constexpr int min_batch_aligned_elements_C = - batch_alignment_bits / sizeof_bits::value; - implementable &= - get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; - } - } - - fusion_implementable = - fusion_implementable && - FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); - } - - if (!implementable) { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " - "requirements for XE 2D copy.\n"); - } - - if (!fusion_implementable) { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements " - "for FusionCallbacks.\n"); - } - - return implementable && fusion_implementable; - } - - CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_, - TensorStorage const& shared_storage_) - : params(params_), - fusion_callbacks(params_.thread, shared_storage_.thread) {} - - CUTLASS_DEVICE - bool is_producer_load_needed() const { - return fusion_callbacks.is_producer_load_needed(); - } - - template - CUTLASS_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - Accumulator accumulators, TiledMma tiled_mma, - int thread_idx, - LoadStoreTensor const& load_store_tensors) { - (void)tiled_mma; - using namespace cute; - - static_assert(cute::rank(CtaTileMNK{}) == 3, - "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(InternalStrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 
3, - "StrideD must be rank-3: [M, N, L]"); - - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - static constexpr auto BLK_M = get<0>(CtaTileMNK{}); - static constexpr auto BLK_N = get<1>(CtaTileMNK{}); - static constexpr auto BLK_K = get<2>(CtaTileMNK{}); - // static_assert(is_same_v, - // "assertion fail"); - static constexpr auto ATOM_M = - get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = - get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = - get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static_assert( - BLK_M % ATOM_M == 0 && BLK_N % ATOM_N == 0 && BLK_K % ATOM_K == 0, - "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); - static constexpr auto SG_M = BLK_M / ATOM_M; - static constexpr auto SG_N = BLK_N / ATOM_N; - static constexpr auto SG_K = BLK_K / ATOM_K; - using SubgroupTileShape = - Shape; - - static constexpr int FragsM = - get<0>(SubgroupTileShape{}) / - get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = - get<1>(SubgroupTileShape{}) / - get<1>(MmaAtomShape()); // B frags per sub_group - - static constexpr int FragmentSize = - (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - auto m_sg = get_sub_group_id() / ATOM_N; - auto n_sg = get_sub_group_id() % ATOM_N; - - // Get the layout and reconstruct the MN mapping equivalent to the old - // get_layoutS_MN() - auto layoutS_TV = params.xe_store_d.get_layoutS_TV(); - auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); - auto layoutS_MN = right_inverse(layoutS_TV).with_shape(mn_shape); - using EpilogueTile = decltype(layoutS_MN.shape()); - - auto sg_local_m_coord = get_sub_group_id() / ATOM_N; - auto sg_local_n_coord = get_sub_group_id() % ATOM_N; - - auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; - auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; - auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); - - bool is_C_load_needed = - is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Represent the full output tensor - Tensor mD_mnl = cute::get_xe_tensor(make_shape(M, N, L)); - - // Tile the output tensor per WG and select the tile for current WG - Tensor g_wg_D = - local_tile(mD_mnl, take<0, 2>(CtaTileMNK{}), - make_coord(m_coord, n_coord, l_coord)); // (BLK_M,BLK_N) - - // Tile the output tensor per SG and select tile for the current SG - Tensor gD = local_tile(g_wg_D, take<0, 2>(SubgroupTileShape{}), - make_coord(m_sg, n_sg)); // (SG_M,SG_N) - - auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); - Tensor tCgD = thread_xe_store_d.partition_D(gD); - - Tensor trC = - make_tensor(Shape>{}); - Tensor trD_compute = - make_tensor(Shape>{}); - - // Because Sm90 uses shared memory, they are not tied to using the same - // accumulator values for MMA and Epilogue. But because we are operating - // directly in the accumulators, we need to be sure that we are operating on - // the same values. 
- ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M, N)); // (M,N) - Tensor cD = local_tile(mD_crd, take<0, 2>(SubgroupTileShape{}), - make_coord(sg_m_coord, sg_n_coord)); - Tensor cD_mn = local_tile(mD_crd, take<0, 2>(CtaTileMNK{}), - make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_g2r.partition_S( - flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - Tensor tRS_cD = - make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - - // Get the fusion callbacks - // Arguments passed here relate to sub-group tiles, rather than CTA - // (work-group) tiles - constexpr bool RefSrc = true; - auto residue_mn = make_coord(M, N); // TODO(Codeplay): this is not correct - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - SubgroupTileShape{}, - sg_coord, - tiled_mma, - EpilogueTile{}, - params.xe_store_d, - cD, - residue_mn, - tRS_cD, - residue_mn, - trC, - thread_idx, - }; - auto cst_callbacks = - fusion_callbacks.template get_consumer_store_callbacks( - cst_args); - - cst_callbacks.begin(); - - auto acc_frag = recast>(accumulators); - auto trD_compute_frag = - recast>(trD_compute); - - Tensor trD = make_tensor(Shape>{}); - auto trD_frag = recast>(trD); - - constexpr int ValuesLoaded = FragsM * FragsN * FragmentSize * SubgroupSize * - ATOM_M * ATOM_N * ATOM_K; - constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); - static_assert( - ValuesLoaded == MN, - "the total elements loaded by all threads should be the same as MxN"); - - auto synchronize = [&]() {}; - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < FragsN; epi_n++) { - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < FragsM; epi_m++) { - if (is_C_load_needed) { - // coordinates for C and D are the same - copy(params.xe_load_c.with(get<0>(load_store_tensors)), - tCgD(_, epi_m, epi_n), trC); - } - - cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); - - auto acc_frag_mn = acc_frag(_, epi_m, epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - trD_compute_frag(epi_v) = - cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); - } - cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, - (epi_m == FragsM - 1 && epi_n == FragsN - 1), - trD_compute_frag); - - if constexpr (is_destination_supported) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = - cutlass::NumericArrayConverter{}( - trD_compute_frag(i)); - } - copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, - tCgD(_, epi_m, epi_n)); - } - } - } - - cst_callbacks.end(); - } - - template - CUTLASS_DEVICE auto update_tensor_shape_stride( - int32_t const& next_group, ProblemShape_MNKL const& problem_shape_mnkl) { - auto [M, N, K, L] = problem_shape_mnkl; - - TensorC mC_mnl; - TensorD mD_mnl; - if constexpr (is_source_supported) { - ElementC const* ptr_C_curr_batch = - reinterpret_cast(params.ptr_C[next_group]); - mC_mnl = - make_tensor(make_gmem_ptr(ptr_C_curr_batch), - make_layout(make_shape(M, N, L), params.dC[next_group])); - } - - if constexpr (is_destination_supported) { - ElementD* ptr_D_curr_batch = - reinterpret_cast(params.ptr_D[next_group]); - mD_mnl = - make_tensor(make_gmem_ptr(ptr_D_curr_batch), - make_layout(make_shape(M, N, L), params.dD[next_group])); - } - return 
cute::make_tuple(mC_mnl, mD_mnl); - } - - private: - Params const& params; - FusionCallbacks fusion_callbacks; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp deleted file mode 100644 index a2abb4b..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp +++ /dev/null @@ -1,360 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/algorithm/functional.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/algorithm/gemm.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { -using namespace cute; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct CollectiveMma, TileShape_, - ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, - GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, - TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, - SmemCopyAtomB_, TransformB_> { - // - // Type Aliases - // - using DispatchPolicy = MainloopIntelXeXMX16Group; - using WorkgroupTileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using InternalStrideA = cute::remove_pointer_t; - using ElementB = ElementB_; - using StrideB = StrideB_; - using InternalStrideB = cute::remove_pointer_t; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - static_assert( - platform::is_same::value, - "MainloopIntelXeXMX16Array requires that A and B have same type."); - - static_assert(std::is_same_v, - "Transformation for A is not currently supported on Intel PVC"); - static_assert(std::is_same_v, - "Transformation for B is not currently supported on Intel PVC"); - - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - - static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); - static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); - static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); - - static constexpr auto ATOM_M = - get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = - get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = - get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); - static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); - static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); - using SubgroupTileShape = - Shape; - - static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - - using Copy_A = typename Copy_Traits< - GmemTiledCopyA, InternalStrideA>::template DefaultTiledCopy; - using Copy_B = typename Copy_Traits< - GmemTiledCopyB, InternalStrideB>::template DefaultTiledCopy; - - using TensorMKL = - decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), - make_shape(0, 0, 0), InternalStrideA{})); //(m, k) - using TensorNKL = - decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), - make_shape(0, 0, 0), InternalStrideB{})); //(n, k) - using MainloopTensors = cute::tuple; - // Host side kernel arguments - struct Arguments { - ElementA const** ptr_A; - StrideA dA; - ElementB const** ptr_B; - StrideB dB; - }; - - struct Params { - ElementA const** ptr_A; - StrideA dA; 
- ElementB const** ptr_B; - StrideB dB; - }; - - // - // Methods - // - - CollectiveMma() = default; - - template - static constexpr Params to_underlying_arguments( - ProblemShape const& problem_shape, Arguments const& args, - void* workspace) { - (void)workspace; - - auto problem_shape_MNK = repeat_like( - typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); - ; - auto init_M = get<0>(problem_shape_MNK); - auto init_N = get<1>(problem_shape_MNK); - auto init_K = get<2>(problem_shape_MNK); - - return Params{args.ptr_A, args.dA, args.ptr_B, args.dB}; - } - - template - static bool can_implement(ProblemShape problem_shapes, - Arguments const& args) { - constexpr int copy_alignment_bits = 128; - constexpr int batch_alignment_bits = 512; - auto problem_shape_MNKL = append<4>(problem_shapes, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - bool implementable = true; - - constexpr int min_aligned_elements_A = - copy_alignment_bits / sizeof_bits::value; - constexpr int min_aligned_elements_B = - copy_alignment_bits / sizeof_bits::value; - constexpr int min_batch_aligned_elements_A = - batch_alignment_bits / sizeof_bits::value; - constexpr int min_batch_aligned_elements_B = - batch_alignment_bits / sizeof_bits::value; - for (int i = 0; i < problem_shapes.groups(); i++) { - auto problem_shape_MNKL = - append<4>(problem_shapes.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - - implementable &= cutlass::detail::check_alignment( - cute::make_shape(M, K, L), InternalStrideA{}); - implementable &= cutlass::detail::check_alignment( - cute::make_shape(N, K, L), InternalStrideB{}); - - if (L > 1) { - implementable &= - get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; - implementable &= - get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; - } - } - - if (!implementable) { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " - "requirements for XE 2D copy.\n"); - } - - return implementable; - } - - /// Perform a subgroup-scoped matrix multiply-accumulate - template - CUTLASS_DEVICE void operator()(FrgTensorD& accum, TensorA gA, TensorB gB, - FrgTensorC const& src_accum, - KTileIterator k_tile_iter, - int const& k_tile_count, - BlkCoord const& blk_coord, int const& K_start, - int const& thread_idx, Params const& mainloop, - LoadTensors const& load_tensors) { - static_assert(is_rmem::value, - "D tensor must be rmem resident."); - static_assert(is_rmem::value, - "C tensor must be rmem resident."); - - (void)thread_idx; - - Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; - Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; - - auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); - auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); - - // Instantiate the MMA object and get thread slice - TiledMma tiled_mma; - // TODO(Codeplay): see if we can make this nicer - // To make all work items in a subgroup have the same global tensors pass in - // the index of work item 0 in each subgroup - auto sg = syclcompat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = - sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; - auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); - - // Partition global counting tensors for MMA - Tensor tCgA = thr_mma.partition_A(gA); - Tensor tCgB = thr_mma.partition_B(gB); - - Tensor tCrA = make_tensor( - make_fragment_layout(tiled_copy_a, tCgA(_, _, _, 0).shape())); - Tensor tCrB = make_tensor( - make_fragment_layout(tiled_copy_b, tCgB(_, _, _, 
0).shape())); - - // Retile registers for copies - Tensor tArA = thr_copy_A.retile_D(tCrA); - Tensor tBrB = thr_copy_B.retile_D(tCrB); - - // Retile global counting tensors for copies - Tensor tAgA = thr_copy_A.retile_S(tCgA); - Tensor tBgB = thr_copy_B.retile_S(tCgB); - - auto tiled_prefetch_a = - cute::prefetch_selector, Int>, Num_SGs>( - tiled_copy_a); - auto tiled_prefetch_b = - cute::prefetch_selector, Int>, Num_SGs>( - tiled_copy_b); - auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); - auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); - - // Partition global tile for prefetch - auto pAgA = thr_prefetch_A.partition_S(gA); - auto pBgB = thr_prefetch_B.partition_S(gB); - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { - print("======================= A: \n"); - print(" gA : "); - print(gA); - print("\n"); - print("tCgA : "); - print(tCgA); - print("\n"); - print("tAgA : "); - print(tAgA); - print("\n"); - - print("===================== B :\n"); - print(" gB : "); - print(gB); - print("\n"); - print("tCgB : "); - print(tCgB); - print("\n"); - print("tBgB : "); - print(tBgB); - print("\n"); - - print("===================== Config: \n"); - print(" threads per workgroup : "); - print(MaxThreadsPerBlock); - print("\n"); - print(" SubgroupTileShape : "); - print(SubgroupTileShape{}); - print("\n"); - } -#endif - - // - // Mainloop - // - const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); - constexpr int barrier_scope = 2; - int prefetch_k = k_start_idx; - - CUTLASS_PRAGMA_UNROLL - for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { - prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); - } - - for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; - k_tile++, prefetch_k++) { - barrier_arrive(barrier_scope); - // Copy gmem to rmem for the first k_tile - copy(tiled_copy_a, tAgA(_, _, _, k_tile), tArA); - copy(tiled_copy_b, tBgB(_, _, _, k_tile), tBrB); - - if (prefetch_k < k_tile_count) { - prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); - } - - cute::gemm(tiled_mma, tCrA, tCrB, accum); - barrier_wait(barrier_scope); - } - } - - template - CUTLASS_DEVICE auto update_tensor_shape_stride( - Params const& mainloop_params, int32_t const& next_group, - ProblemShape_MNKL const& problem_shape_mnkl) { - const int32_t M = get<0>(problem_shape_mnkl); - const int32_t N = get<1>(problem_shape_mnkl); - const int32_t K = get<2>(problem_shape_mnkl); - - ElementA const* ptr_A_curr_batch = - reinterpret_cast(mainloop_params.ptr_A[next_group]); - ElementB const* ptr_B_curr_batch = - reinterpret_cast(mainloop_params.ptr_B[next_group]); - - Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), - make_shape(M, K, (int32_t)1), - mainloop_params.dA[next_group]); - Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), - make_shape(N, K, (int32_t)1), - mainloop_params.dB[next_group]); - - return cute::make_tuple(mA, mB); - } -}; - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp deleted file mode 100644 index ca749c3..0000000 --- a/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp +++ /dev/null @@ -1,234 +0,0 @@ 
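Editor's note on the deleted grouped mainloop above: `operator()` primes a prefetch window of `DispatchPolicy::Stages` k-tiles, then keeps that window running ahead of the copy/MMA pair, with split barriers (`barrier_arrive` / `barrier_wait`, scope 2) bracketing each iteration, while `update_tensor_shape_stride` rebuilds the A/B global tensors from the per-group pointer and stride arrays (`ptr_A[next_group]`, `dA[next_group]`) whenever the scheduler moves to the next group. The snippet below is a minimal, library-free sketch of just the k-loop scheduling, assuming a fixed prefetch distance; the helper names (`prefetch_tile`, `load_tile`, `mma_tile`) are placeholders, not CUTLASS, CuTe, or SYCL APIs, and the barriers are elided.

// Illustration only: software-pipelined k-loop with a fixed prefetch distance,
// mirroring the structure of the deleted mainloop (prime, then loop with a
// guarded prefetch each iteration). Host-side stand-in, not device code.
#include <cstdio>

namespace {
constexpr int kStages = 3;  // stands in for DispatchPolicy::Stages

void prefetch_tile(int k) { std::printf("prefetch k=%d\n", k); }  // issue A/B prefetch
void load_tile(int k)     { std::printf("load     k=%d\n", k); }  // gmem -> register copy
void mma_tile(int k)      { std::printf("mma      k=%d\n", k); }  // accumulate on loaded tile
}  // namespace

int main() {
  const int k_tile_count = 8;
  const int k_start_idx = 0;

  // Prime the pipeline: issue the first `kStages` prefetches before any compute.
  int prefetch_k = k_start_idx;
  for (; prefetch_k < kStages; ++prefetch_k) {
    prefetch_tile(prefetch_k);
  }

  // Steady state: load the current tile, keep the prefetch window ahead while
  // tiles remain, then multiply-accumulate on the freshly loaded fragments.
  for (int k_tile = k_start_idx; k_tile < k_start_idx + k_tile_count;
       ++k_tile, ++prefetch_k) {
    load_tile(k_tile);
    if (prefetch_k < k_tile_count) {
      prefetch_tile(prefetch_k);
    }
    mma_tile(k_tile);
  }
  return 0;
}

In the deleted kernel the same structure runs once per subgroup, and `can_implement` gates the launch by checking 128-bit copy alignment and, for batched shapes, 512-bit batch-stride alignment for every group in the problem set.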
-/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include // cute::DefaultCopy -#include // cute::is_base_of_v -// #include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "xe_array_epilogue.hpp" -#include "xe_callbacks.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Used to specify epilogue subtile shape or dispatch to automatic computation -// of subtile shape -struct EpilogueTileAuto {}; - -// Used to let the builder pick the epilogue schedule automatically. 
-// Can be overridden with kernel schedule tags in -// cutlass/gemm/dispatch_policy.hpp -struct EpilogueScheduleAuto {}; - -template < - class ArchTag, class OpClass, class TileShape_MNK, class ClusterShape_MNK, - class EpilogueTileType, class ElementAccumulator, class ElementCompute, - class ElementC, class GmemLayoutTagC, int AlignmentC, class ElementD, - class GmemLayoutTagD, int AlignmentD, class EpilogueScheduleType, - class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination< - ElementD, ElementCompute, ElementC, ElementCompute>, - class Enable = void> -struct CollectiveBuilder { - static_assert(cutlass::detail::dependent_false, - "Could not build a collective epilogue for given parameters."); -}; - -// helper sub-builder for epilogue fusion callbacks (for internal use by -// CollectiveBuilder only) -namespace detail { - -// callbacks builder with operation tag -template -struct CallbacksBuilder { - using Callbacks = fusion::FusionCallbacks; -}; - -// callbacks builder with callbacks passthrough -template -struct CallbacksBuilder>> { - using Callbacks = FusionCallbacks; -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::collective { - -namespace detail { -template -struct FusionOpInfo { - static_assert(cutlass::detail::dependent_false, - "Could not find a builder specialization."); -}; - -template -struct FusionOpInfo> { - constexpr static bool HasBuilder = true; - - template - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinearCombination, - TileShape_MNK, EpilogueTile>; -}; - -template