diff --git a/.github/workflows/ci-slang-test.yml b/.github/workflows/ci-slang-test.yml index 4a889aeaf5..1dd1840970 100644 --- a/.github/workflows/ci-slang-test.yml +++ b/.github/workflows/ci-slang-test.yml @@ -43,7 +43,7 @@ on: jobs: test-slang: runs-on: ${{ fromJSON(inputs.runs-on) }} - timeout-minutes: 30 + timeout-minutes: 60 defaults: run: shell: bash @@ -93,6 +93,11 @@ jobs: slang_test_args+=("-expected-failure-list" "tests/expected-failure-linux-gpu.txt") fi + # Skip most neural tests in debug builds to reduce CI time + if [[ "${{ inputs.config }}" == "debug" ]]; then + slang_test_args+=("-skip-list" "tests/skip-list-debug.txt") + fi + # Execute slang-test with all arguments "$bin_dir/slang-test" "${slang_test_args[@]}" - name: Run Slang examples @@ -174,7 +179,7 @@ jobs: # Includes both slangpy pytest tests and slangpy-samples examples. test-slangpy: runs-on: ${{ fromJSON(inputs.runs-on) }} - timeout-minutes: 30 + timeout-minutes: 60 if: inputs.full-gpu-tests && (github.event_name == 'pull_request' || inputs.config == 'release') defaults: run: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0eebf725c4..d7dff50781 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -146,6 +146,7 @@ jobs: -DSLANG_GENERATORS_PATH=build-platform-generators/bin \ -DSLANG_ENABLE_EXAMPLES=OFF \ -DSLANG_ENABLE_RELEASE_LTO=ON \ + -DSLANG_STANDARD_MODULE_DEVELOP_BUILD=OFF \ "-DSLANG_SLANG_LLVM_FLAVOR=$( [[ "${{matrix.build-slang-llvm}}" = "true" ]] && echo "USE_SYSTEM_LLVM" || echo "DISABLE")" \ ${{matrix.extra-cmake-flags}} diff --git a/CMakeLists.txt b/CMakeLists.txt index 268d2ae4ba..f83045c021 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -130,6 +130,12 @@ option(SLANG_ENABLE_IR_BREAK_ALLOC "Enable _debugUID on IR allocation") option(SLANG_ENABLE_ASAN "Enable ASAN (address sanitizer)") option(SLANG_ENABLE_COVERAGE "Enable code coverage instrumentation") +option( + SLANG_STANDARD_MODULE_DEVELOP_BUILD + "Enable development build for standard modules (enables UNIT_TEST macro). Disable for release builds." + ON +) + option(SLANG_ENABLE_PREBUILT_BINARIES "Enable using prebuilt binaries" ON) option(SLANG_ENABLE_GFX "Enable gfx targets" ON) option(SLANG_ENABLE_SLANGD "Enable language server target" ON) diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 0476ab6aee..be4925b751 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -6462,7 +6462,7 @@ struct IsSaturated // ==================================================================================== template -__device__ inline void wmmaLoad(uint32_t* regs, const ElemT* ptr, int stride) +__device__ inline void wmmaLoad(uint32_t* regs, const void* ptr, int stride) { constexpr int nregs = RegisterCount::value; @@ -6527,7 +6527,7 @@ __device__ inline void wmmaLoad(uint32_t* regs, const ElemT* ptr, int stride) // ==================================================================================== template -__device__ inline void wmmaStore(ElemT* ptr, const uint32_t* regs, int stride) +__device__ inline void wmmaStore(void* ptr, const uint32_t* regs, int stride) { constexpr int nregs = RegisterCount::value; @@ -6623,7 +6623,7 @@ inline unsigned __device__ Pack32Helper(unsigned char value) // The dimensions of the fragment are specified by M, N, K which are totally determined during // compile time, so slang already did the pre-filter on the shape & type combination. 
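For reference, a minimal CUDA sketch of the register math that RegisterCount<...>::value is assumed to encode for these fragments (the trait's specializations are not shown in this hunk); the elements_per_warp / bytes_per_thread members removed later in this hunk compute the same quantity. RegisterCountSketch and MatrixUseSketch are illustrative names, not part of the prelude.

#include <cuda_fp16.h>

enum class MatrixUseSketch { MatrixA, MatrixB, Accumulator };

// Registers per lane for a WMMA fragment: the 32 lanes of a warp jointly hold the whole
// M x K (A), K x N (B), or M x N (C/D) tile, and each lane packs its share of elements
// into 32-bit registers.
template<typename ElemT, int M, int N, int K, MatrixUseSketch Use>
struct RegisterCountSketch
{
    static constexpr int elements_per_warp = (Use == MatrixUseSketch::MatrixA)   ? (M * K)
                                             : (Use == MatrixUseSketch::MatrixB) ? (K * N)
                                                                                 : (M * N);
    static_assert(elements_per_warp % 32 == 0, "Total elements per warp must be divisible by 32");
    static constexpr int elements_per_thread = elements_per_warp / 32;
    static constexpr int bytes_per_thread = elements_per_thread * int(sizeof(ElemT));
    static constexpr int value = (bytes_per_thread + 3) / 4; // 32-bit registers per lane
};

// A 16x16x16 half A-fragment needs 4 registers per lane; a 16x16x16 float accumulator needs 8,
// which is where the old MAX_REGS = 8 upper bound came from.
static_assert(RegisterCountSketch<__half, 16, 16, 16, MatrixUseSketch::MatrixA>::value == 4, "");
static_assert(RegisterCountSketch<float, 16, 16, 16, MatrixUseSketch::Accumulator>::value == 8, "");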
-template +template struct WmmaFragment { __device__ WmmaFragment() {} @@ -6647,25 +6647,19 @@ struct WmmaFragment void __device__ fill(T value) { unsigned packed = Pack32Helper(value); - - // Manually assign to prevent register coalescing - regs[0] = packed; - regs[1] = packed; - regs[2] = packed; - regs[3] = packed; - regs[4] = packed; - regs[5] = packed; - regs[6] = packed; - regs[7] = packed; + constexpr int nregs = RegisterCount::value; + for (int i = 0; i < nregs; i++) + { + regs[i] = packed; + } } __device__ This operator*(T b) { - constexpr int nregs = RegisterCount::value; This result; // This loop will be unrolled by the compiler becuase nregs is constexpr - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, get(i) * b); } @@ -6674,11 +6668,10 @@ struct WmmaFragment __device__ This operator*(const This& b) { - constexpr int nregs = RegisterCount::value; This result; // This loop will be unrolled by the compiler becuase nregs is constexpr - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, get(i) * b.get(i)); } @@ -6687,10 +6680,9 @@ struct WmmaFragment __device__ This operator/(const This& other) { - constexpr int nregs = RegisterCount::value; This result; - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, get(i) / other.get(i)); } @@ -6699,10 +6691,9 @@ struct WmmaFragment __device__ This operator-(const This& other) { - constexpr int nregs = RegisterCount::value; This result; - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, get(i) - other.get(i)); } @@ -6711,10 +6702,9 @@ struct WmmaFragment __device__ This operator-() { - constexpr int nregs = RegisterCount::value; This result; - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, -get(i)); } @@ -6723,10 +6713,9 @@ struct WmmaFragment __device__ This operator+(const This& other) { - constexpr int nregs = RegisterCount::value; This result; - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, get(i) + other.get(i)); } @@ -6735,10 +6724,9 @@ struct WmmaFragment __device__ This operator%(const This& other) { - constexpr int nregs = RegisterCount::value; This result; - for (int i = 0; i < nregs; i++) + for (int i = 0; i < GetLength(); i++) { result.set(i, get(i) % other.get(i)); } @@ -6751,7 +6739,7 @@ struct WmmaFragment // If the data type is different, we need to copy element by element. // Since the shape of two matrices are the same, they have the same // number of elements. - for (int i = 0; i < elements_per_thread; i++) + for (int i = 0; i < GetLength(); i++) { set(i, static_cast(other.get(i))); } @@ -6763,7 +6751,6 @@ struct WmmaFragment // - index 1: bits [8:15] of regs[0] // - index 2: bits [16:23] of regs[0] // - index 3: bits [24:31] of regs[0] - // - index 4: bits [0:7] of regs[1], etc. __device__ T get(int index) const { if constexpr (sizeof(T) == 4) @@ -6848,40 +6835,52 @@ struct WmmaFragment wmmaStore(buffer + element, regs, stride); } + template + void __device__ Store(U* buffer, uint stride) + { + // Force compile-time check, so we know the template parameter comibination is valid. + (void)RegisterCount::value; + wmmaStore(buffer, regs, stride * sizeof(U) / sizeof(T)); + } + template static This __device__ Load(T* buffer, uint element, uint stride) { - WmmaFragment fragment; + WmmaFragment fragment; // Force compile-time check, so we know the template parameter comibination is valid. 
(void)RegisterCount::value; wmmaLoad(fragment.regs, buffer + element, stride); + fragment.m_layout = layout; + return fragment; + } + + template + static This __device__ Load(U* buffer, uint stride) + { + WmmaFragment fragment; + // Force compile-time check, so we know the template parameter comibination is valid. + (void)RegisterCount::value; + wmmaLoad(fragment.regs, buffer, stride * sizeof(U) / sizeof(T)); + fragment.m_layout = layout; return fragment; } - static __device__ uint32_t GetLength() { return This::elements_per_thread; } + static constexpr __device__ uint32_t GetLength() { return This::elements_per_thread; } // For referencing those template parameters outside the struct using ElementType = T; static constexpr int m_M = M; static constexpr int m_N = N; static constexpr int m_K = K; - static constexpr Layout m_layout = MatrixLayout; - - // Maximum registers needed across all fragment types and data types - static constexpr int MAX_REGS = 8; - uint32_t regs[MAX_REGS] = {}; + Layout m_layout = Layout::RowMajor; - static constexpr uint32_t elements_per_warp = (R == MatrixUse::MatrixA) ? (M * K) - : (R == MatrixUse::MatrixB) ? (K * N) - : (M * N); + // Register Count requirement + static constexpr int RegsCount = RegisterCount::value; + unsigned regs[RegsCount] = {}; - static_assert(elements_per_warp % 32 == 0, "Total elements per warp must be divisible by 32"); - - static constexpr uint32_t elements_per_thread = elements_per_warp / 32; - static constexpr uint32_t bytes_per_thread = elements_per_thread * sizeof(T); - static constexpr uint32_t registers_per_thread = (bytes_per_thread + 3) / 4; + static constexpr uint32_t elements_per_thread = RegsCount * (4 / sizeof(T)); }; // ==================================================================================== @@ -7350,12 +7349,10 @@ template< int M, int N, int K, - Layout layoutA, - Layout layoutB, bool saturatingAccumulation> WmmaFragment __device__ coopMatMulAdd( - WmmaFragment matA, - WmmaFragment matB, + WmmaFragment matA, + WmmaFragment matB, WmmaFragment matC) { constexpr ShapeCombination shape = (M == 16 && N == 16 && K == 16) ? ShapeCombination::m16n16k16 @@ -7364,11 +7361,60 @@ WmmaFragment __device__ coopMatMulAdd( : ShapeCombination::m32n8k16; WmmaFragment matD; - MMAHelper::eval( - matD, - matA, - matB, - matC); + uint32_t encodedLayout = (matA.m_layout == Layout::RowMajor ? 1 : 0) << 1 | + (matB.m_layout == Layout::RowMajor ? 
1 : 0); + + switch (encodedLayout) + { + // 00011 + case 0x3: + MMAHelper< + AType, + BType, + CType, + DType, + shape, + Layout::RowMajor, + Layout::RowMajor, + saturatingAccumulation>::eval(matD, matA, matB, matC); + break; + // 00010 + case 0x2: + MMAHelper< + AType, + BType, + CType, + DType, + shape, + Layout::RowMajor, + Layout::ColMajor, + saturatingAccumulation>::eval(matD, matA, matB, matC); + break; + // 0001 + case 0x01: + MMAHelper< + AType, + BType, + CType, + DType, + shape, + Layout::ColMajor, + Layout::RowMajor, + saturatingAccumulation>::eval(matD, matA, matB, matC); + break; + // 0000 + case 0x00: + MMAHelper< + AType, + BType, + CType, + DType, + shape, + Layout::ColMajor, + Layout::ColMajor, + saturatingAccumulation>::eval(matD, matA, matB, matC); + break; + } return matD; } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 043b004547..50ad8dae01 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5437,7 +5437,7 @@ When generating code for other targets, this parameter is ignored and has no eff __generic __magic_type(HLSLStructuredBufferType) __intrinsic_type($(kIROp_HLSLStructuredBufferType)) -struct StructuredBuffer +struct StructuredBuffer : IArray { /// Get the dimensions of the buffer. @@ -6629,7 +6629,7 @@ When generating code for other targets, this parameter is ignored and has no eff @see `StructuredBuffer`, `AppendStructuredBuffer`, `ConsumeStructuredBuffer` @category buffer_types **/ -struct $(item.name) +struct $(item.name) : IRWArray { /// Decrements the object's hidden counter. /// @return The post-decremented counter value. @@ -27156,6 +27156,28 @@ struct CoopMat } } + [ForceInline] + [require(cooperative_matrix)] + void Store< + let matrixLayout : CoopMatMatrixLayout, + >(Ptr data, uint stride) + { + __target_switch + { + case spirv: + let alignment = 16U; + spirv_asm + { + OpCooperativeMatrixStoreKHR $data $this $matrixLayout $stride Aligned !alignment; + }; + case cuda: + if (matrixLayout == CoopMatMatrixLayout.RowMajor) + __intrinsic_asm "$0.Store(($1)->m_data, 0, $2)"; + else + __intrinsic_asm "$0.Store(($1)->m_data, 0, $2)"; + } + } + /// Stores the cooperative matrix into a groupshared (workgroup) memory array with a different element type. /// Vulkan only. /// @param matrixLayout The memory layout (RowMajor or ColMajor) to use when storing. @@ -27181,6 +27203,29 @@ struct CoopMat }; } + [ForceInline] + [require(cooperative_matrix)] + void Store< + let matrixLayout : CoopMatMatrixLayout, + U, + >(Ptr data, uint stride) + { + __target_switch + { + case spirv: + let alignment = 16U; + spirv_asm + { + OpCooperativeMatrixStoreKHR $data $this $matrixLayout $stride Aligned !alignment; + }; + case cuda: + if (matrixLayout == CoopMatMatrixLayout.RowMajor) + __intrinsic_asm "$0.Store(($1), $2)"; + else + __intrinsic_asm "$0.Store(($1), $2)"; + } + } + /// Stores the cooperative matrix into a groupshared (workgroup) memory array of vectors. /// Vulkan only. /// @param matrixLayout The memory layout (RowMajor or ColMajor) to use when storing. @@ -27191,7 +27236,7 @@ struct CoopMat /// @param element The starting element index in the array. /// @param stride The stride in elements between consecutive rows (for row major) or columns (for column major). 
[ForceInline] - [require(cooperative_matrix)] + [require(cooperative_matrix_spirv)] void Store< let matrixLayout : CoopMatMatrixLayout, U, @@ -27350,6 +27395,28 @@ $} } } + [ForceInline] + [require(cooperative_matrix)] + static This Load< + let matrixLayout : CoopMatMatrixLayout, + >(Ptr data, uint stride) + { + __target_switch + { + case spirv: + let alignment = 16U; + return spirv_asm + { + result:$$CoopMat = OpCooperativeMatrixLoadKHR $data $matrixLayout $stride Aligned !alignment; + }; + case cuda: + if (matrixLayout == CoopMatMatrixLayout.RowMajor) + __intrinsic_asm "$TR::Load($0, 0, $1)"; + else + __intrinsic_asm "$TR::Load($0, 0, $1)"; + } + } + /// Loads a cooperative matrix from a groupshared (workgroup) memory array with a different element type. /// Vulkan only. /// @param matrixLayout The memory layout (RowMajor or ColMajor) of the data in the array. @@ -27376,6 +27443,29 @@ $} }; } + [ForceInline] + [require(cooperative_matrix)] + static This Load< + let matrixLayout : CoopMatMatrixLayout, + U, + >(Ptr data, uint stride) + { + __target_switch + { + case spirv: + let alignment = 16U; + return spirv_asm + { + result:$$CoopMat = OpCooperativeMatrixLoadKHR $data $matrixLayout $stride Aligned !alignment; + }; + case cuda: + if (matrixLayout == CoopMatMatrixLayout.RowMajor) + __intrinsic_asm "$TR::Load($0, $1)"; + else + __intrinsic_asm "$TR::Load($0, $1)"; + } + } + /// Loads a cooperative matrix from a groupshared (workgroup) memory array of vectors. /// Vulkan only. /// @param matrixLayout The memory layout (RowMajor or ColMajor) of the data in the array. @@ -28383,8 +28473,6 @@ CoopMat coopMatMulAdd< $T0::m_M, $T0::m_N, $T0::m_K, - $T0::m_layout, - $T1::m_layout, true >($0, $1, $2))"; else @@ -28396,8 +28484,6 @@ CoopMat coopMatMulAdd< $T0::m_M, $T0::m_N, $T0::m_K, - $T0::m_layout, - $T1::m_layout, false >($0, $1, $2))"; diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index e7cda137d1..eec6d63678 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1237,6 +1237,7 @@ struct SpecializationContext { this->changed = true; eliminateDeadCode(module->getModuleInst()); + peepholeOptimizeGlobalScope(targetProgram, this->module); applySparseConditionalConstantPropagationForGlobalScope( this->module, targetProgram, diff --git a/source/standard-modules/neural/CMakeLists.txt b/source/standard-modules/neural/CMakeLists.txt index 83f9f02ac3..4381d3e860 100644 --- a/source/standard-modules/neural/CMakeLists.txt +++ b/source/standard-modules/neural/CMakeLists.txt @@ -44,10 +44,17 @@ else() set(SLANG_COMPILER_DEPENDENCY) endif() +# Determine compiler flags based on build type +if(SLANG_STANDARD_MODULE_DEVELOP_BUILD) + set(NEURAL_MODULE_DEFINES "-DUNIT_TEST") +else() + set(NEURAL_MODULE_DEFINES "") +endif() + # Build neural.slang-module and output to the output directory add_custom_command( OUTPUT ${neural_module_file} - COMMAND ${SLANG_COMPILER} neural.slang -o ${SLANG_NEURAL_MODULE_FILE_NAME} + COMMAND ${SLANG_COMPILER} ${NEURAL_MODULE_DEFINES} neural.slang -o ${SLANG_NEURAL_MODULE_FILE_NAME} DEPENDS ${neural_copied_files} ${SLANG_COMPILER_DEPENDENCY} WORKING_DIRECTORY ${neural_output_dir} VERBATIM diff --git a/source/standard-modules/neural/accelerate-vector-coopmat.slang b/source/standard-modules/neural/accelerate-vector-coopmat.slang new file mode 100644 index 0000000000..aff467a080 --- /dev/null +++ b/source/standard-modules/neural/accelerate-vector-coopmat.slang @@ -0,0 +1,1273 @@ +// Unit test mode is used 
for unit testing the tiled MMA implementation. +// So we can test this single file by providing -DUNIT_TEST to the compiler. +implementing neural; + +#ifdef UNIT_TEST +#define VISIBILITY_LEVEL public +#else +#define VISIBILITY_LEVEL internal +#endif + +VISIBILITY_LEVEL enum TargetEnum : uint32_t +{ + CUDA = 0, + SPIR_V = 1, +} + +public enum ExecutionMode : uint32_t +{ + Inference = 0, + Training = 1, +} + +[ForceInline] +internal uint getWaveId() +{ + __target_switch + { + case cuda: + uint3 tid = cudaThreadIdx(); + uint3 blockDim = cudaBlockDim(); + uint flattenedTid = tid.x + tid.y * blockDim.x + tid.z * blockDim.x * blockDim.y; + return flattenedTid / WaveGetLaneCount(); + case spirv: + return spirv_asm { + OpCapability GroupNonUniform; + result:$$uint = OpLoad builtin(SubgroupId:uint); + }; + } +} + +[ForceInline] +internal int getWaveCount() +{ + // Note, we always require the threads count is multiple of the subgroup size, therefore we don't need to round up the result. + __stage_switch + { + case compute: + __target_switch + { + case cuda: + uint3 blockDim = cudaBlockDim(); + int warpsPerBlock = (blockDim.x * blockDim.y * blockDim.z) >> 5; + return warpsPerBlock; + case spirv: + uint3 workGroupSize = WorkgroupSize(); + int subGroupSize = WaveGetLaneCount(); + return (workGroupSize.x * workGroupSize.y * workGroupSize.z) / subGroupSize; + } + default: + // We need this because WorkgroupSize() call requires compute stage only. + static_assert(false, "Only support compute stage"); + return 0; + } +} + +// We can't convert T* to uint4* on shared memory in slang, therefore, we will provide two versions of shared memory pointer +// and implement both of them, and document about the difference between them. +internal typealias SPtr = Ptr; + +internal struct CoopMatShape + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + static const int ROW_A = 16; + static const int COLUMN_A = 16; + + // TODO: Currently, we only support floating point data type, therefore, we can always use 16x16x16. + // Once in the future we want to extend to support integer type, we can uncomment the following code. + static const int ROW_B = 16; // (Target == TargetEnum.CUDA) ? 16 : (sizeof(T) == 1 ? 32 : 16); + static const int COLUMN_B = 16; + + static const int ROW_C = 16; + static const int COLUMN_C = 16; + + static const int SizeAInElements = ROW_A * COLUMN_A; + static const int SizeBInElements = ROW_B * COLUMN_B; + static const int SizeCInElements = ROW_C * COLUMN_C; + + // Because the vectorized load is along K dimension (row of A and column of B), so we need to check + // if K is aligned with the vector size (ElementCountPerVector), we can use the vectorized load, + // otherwise, we need to use the scalar load. + // ElementCountPerVector is measured in T. + static const uint ElementCountPerVector = sizeof(uint4) / sizeof(half); + + // Since A and B can only be half type, ElementCountPerVector is fixed. However, C can be both half and float. 
+ static const uint ElementCountPerVectorMatC = sizeof(uint4) / sizeof(T); + + static const uint CoopMatASizeInVector = (ROW_A * COLUMN_A) / ElementCountPerVector; + static const uint CoopMatBSizeInVector = (ROW_B * COLUMN_B) / ElementCountPerVector; + static const uint CoopMatCSizeInVector = (ROW_C * COLUMN_C) / ElementCountPerVectorMatC; +} + +struct TileInfo + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + typealias MatShape = CoopMatShape; + + static const bool IsAlignedVector = (K % MatShape.ElementCountPerVector) == 0; + + // For the Tile A, how many cooperative matrices in a row + static const int NCoopMatRow = TransposeA ? (K + MatShape.COLUMN_A - 1) / MatShape.COLUMN_A : + (M + MatShape.ROW_A - 1) / MatShape.ROW_A; + + // Notice that because of our specific workload, NTilesColumn is always warp-size (as long as it's more than half warp) or half warp (<= half warp). + // For the Tile B, how many cooperative matrices in a column + static const int NCoopMatColumn = (N + MatShape.COLUMN_B - 1) / MatShape.COLUMN_B; + + // Total number of cooperative matrices that consist the result of C = A * B + static const int NCoopMat = NCoopMatRow * NCoopMatColumn; + + // When A * B, k is the shared dimension, and when A^T * B, m is the shared dimension. + static const uint SharedDimensionSize = TransposeA ? Uint4AlignedM : Uint4AlignedK; + + // `Uint4AlignedK` is the K value that is aligned with the uint4 vector size. Because we always do the vectorized load + // along K dimension, so we need to align the K value with the vector size. + // Measured in T. + static const int Uint4AlignedK = ((K + (MatShape.ElementCountPerVector - 1)) / MatShape.ElementCountPerVector) * MatShape.ElementCountPerVector; + static const int Uint4AlignedM = ((M + (MatShape.ElementCountPerVector - 1)) / MatShape.ElementCountPerVector) * MatShape.ElementCountPerVector; + + // The height of tile A measure in T and uint4. + static const uint HeightInElementsTileA = NCoopMatRow * MatShape.ROW_A; + static const uint HeightInVectorTileA = HeightInElementsTileA / MatShape.ElementCountPerVector; + + // In order to achieve the high throughput, we will write data to shared memory in vectorized way. + // e.g. if T is half, it's 2 bytes, so uint4 can store 8 elements. + // Since the size in bytes of the tile in one row is T.COLUMN * sizeof(T), + // the number of vectorized columns in one row of cooperative matrix A is T.COLUMN * sizeof(T) / sizeof(uint4). + // `1` means one element of uint4. + static const uint WidthInVectorTileA = MatShape.COLUMN_A / MatShape.ElementCountPerVector; + + // WidthInElementTileA is the number of elements in one row of the tile in original data type. + static const uint WidthInElementTileA = MatShape.COLUMN_A; + + // Similarly, the number of vectorized rows in one column of cooperative matrix B is T.ROW * sizeof(T) / sizeof(uint4). + // `1` means one element of uint4 + static const uint HeightInVectorTileB = MatShape.ROW_B / MatShape.ElementCountPerVector; + + // So HeightInElementTileB is the number of elements in one column of the tile, measured in T. + static const uint HeightInElementTileB = HeightInVectorTileB * MatShape.ElementCountPerVector; + + // The height of tile B measure in T and uint4. 
+ static const int WidthInElementsTileB = NCoopMatColumn * MatShape.COLUMN_B; + static const int WidthInVectorTileB = WidthInElementsTileB / MatShape.ElementCountPerVector; + + static const uint HeightInVectorTileC = MatShape.ROW_C / MatShape.ElementCountPerVectorMatC; +} + +// A wrapper of matrix type to access each tile of the matrix +// Note: The reason we want to make SubgroupSize as a generic parameter instead of using WaveGetLaneCount() is that +// we need the logic to be expand as much as possible during compile time. WaveGetLaneCount() is a runtime function, +// so if we use it directly in our mma function, some of branches won't be eliminated during compile time. +[require(cooperative_matrix)] +VISIBILITY_LEVEL struct MMAHelper + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + typealias TileInfo = TileInfo; + typealias CMShape = CoopMatShape; + typealias ShMemInfo = SharedMemoryUsage; + + static const linalg.CoopMatMatrixLayout RowMajor = linalg.CoopMatMatrixLayout.RowMajor; + static const linalg.CoopMatMatrixLayout ColumnMajor = linalg.CoopMatMatrixLayout.ColumnMajor; + + // Weight matrix is in shape of M x K + // Input vector will be collected by each thread in a sub-group to form a matrix in shape of K x N + // Note, M, K, N are measured in T. + // Since Cooperative matrix is only worked for a whole subgroup (warp), so we always use the SubgroupSize as the N value. + static const int M = OutputSize; + static const int N = SubgroupSize; + static const int K = InputSize; + + VISIBILITY_LEVEL static const int Uint4AlignedK = TileInfo.Uint4AlignedK; + VISIBILITY_LEVEL static const int Uint4AlignedM = TileInfo.Uint4AlignedM; + + // NumLoadsEachThreadMatA is the number of loads each thread needs to perform. + // N is the number of active threads in the subgroup. + // For matrix A, each iteration we load a column of matrix tile of Matrix A, which is height * sizeof(elementType) x 16 bytes. + // And each thread can load 16 bytes for each load, so the number of loads each thread needs to do for half type is + // (M * 32) / (N * 16) = M * 2 / N, and we just round up the result to the nearest integer. + static const uint NumLoadsEachThreadMatA = (TileInfo.HeightInElementsTileA * 2 + N - 1) / N; + + static const uint NumStoresEachThreadMatA = (sizeof(T) == 2) ? + NumLoadsEachThreadMatA : + (TileInfo.HeightInElementsTileA * 4 + N - 1) / N; + + // OffsetPerThreadLoadMatA is the offset between each load. + // The reason is same as above, since the whole subgroup can load N * 16 bytes, and the load is consecutive, + // so the offset is the N * 16 bytes, and we measure it in uint4 units, so it's N. + static const uint OffsetPerThreadLoadMatA = N; + + // A/B can only be half, so we don't need to use generic parameter here. + typealias MatA = linalg.CoopMat; + typealias MatB = linalg.CoopMat; + typealias MatC = linalg.CoopMat; + + // TODO: This will always double the buffer usage, therefore we haven't really implemented the double buffer yet. 
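As a cross-check on the NumLoadsEachThreadMatA derivation earlier in this struct, here is the same arithmetic as a standalone constexpr sketch (C++/CUDA for illustration only; the helper name and the sample sizes are assumptions, and it assumes 2-byte half weights, a 16-element-wide tile column, and 16-byte uint4 loads as described above):

constexpr int loadsPerThreadMatA(int heightInElements, int laneCount)
{
    // One tile column of A is heightInElements rows x 16 half columns = heightInElements * 32 bytes;
    // the subgroup moves laneCount * 16 bytes per round of uint4 loads, so round up the quotient.
    return (heightInElements * 32 + laneCount * 16 - 1) / (laneCount * 16);
}
static_assert(loadsPerThreadMatA(64, 32) == 4, "64-row tile, full warp: 4 uint4 loads per lane");
static_assert(loadsPerThreadMatA(16, 32) == 1, "16-row tile, full warp: 1 uint4 load per lane");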
+ static const bool EnableDoubleBuffer = false; + + [ForceInline] + static uint flattenedIndex(uint2 coord) + { + if (TransposeA) + { + return coord.y + coord.x * Stride; + } + else + { + return coord.x + coord.y * Stride; + } + } + + [ForceInline] + static uint2 vectorizedTileCoordToWeightCoord(uint2 coordInTile, int tileIndex) + { + uint2 coordOut; + if (TransposeA) + { + coordOut.x = coordInTile.x + tileIndex * TileInfo.WidthInElementTileA; + coordOut.y = coordInTile.y * ElemCountPerVector; + } + else + { + coordOut.x = coordInTile.x * ElemCountPerVector + tileIndex * TileInfo.WidthInElementTileA; + coordOut.y = coordInTile.y; + } + return coordOut; + } + + [ForceInline] + static uint2 vectorizedTileIndexTo2DCoord(uint indexInVectorizedTile) + { + uint2 coordOut; + if (TransposeA) + { + coordOut.x = indexInVectorizedTile / HeightInVector; + coordOut.y = indexInVectorizedTile % HeightInVector; + } + else + { + coordOut.x = indexInVectorizedTile % WidthInVector; + coordOut.y = indexInVectorizedTile / WidthInVector; + } + return coordOut; + } + // Load a tiled column of the Matrix A into shared memory. + // We load a column of matrix tile of Matrix A each time, because our iteration is over K dimension. + // We require the Matrix A is stored as row major in the weight storage. + // Also we load the tile into shared memory in row major. + // For each workgroup (thread block), we only use the first Subgroup (warp) to load the data. + // This is just balanced decision between simplicity and bank conflict. Given that the shared memory tile + // is not that large, if the warp is full, then each thread only needs to load 2 twice, therefore there is + // no need to use more than one warp. And more warp could also make the algorithm avoiding bank conflict more + // complicated. + // So implement as this for now. We can profiling the application later, if it turns out to be a bottleneck, + // we can consider using more than one warp in the future. + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + internal static void matALoadStore( + SPtr sharedMemory, + int tileIndex, + Storage weightStorage, + Storage.Address weightAddress) + where U : __BuiltinFloatingPointType + where U.Differential == U + where Storage : IStorage + { + uint indexInVectorizedTile = WaveGetLaneIndex(); + + // In load case, since we always load the matrix into MatA, so the tile must in half type, because that is + // what MatA can only support. + // However, in store, we store the MatC which could be float type, so we need to adjust the tile height in vector + // and element count per vector. + static const int elementCountPerVector = isStore ? CMShape.ElementCountPerVectorMatC : CMShape.ElementCountPerVector; + static const int heightInVector = isStore ? + TileInfo.HeightInVectorTileA * CMShape.ElementCountPerVector / CMShape.ElementCountPerVectorMatC : + TileInfo.HeightInVectorTileA; + static const int widthInVector = isStore ? + TileInfo.WidthInVectorTileA * CMShape.ElementCountPerVector / CMShape.ElementCountPerVectorMatC : + TileInfo.WidthInVectorTileA; + static const int numLoadOrStorePerThread = isStore ? NumStoresEachThreadMatA : NumLoadsEachThreadMatA; + + // get the 2-D coordinate of the tile in the vectorized tile + const uint2 coordInTile = vectorizedTileIndexTo2DCoord(indexInVectorizedTile); + + // tx and ty are the 2-D coordinate of scalar tile. + // row coordinate ty is the same between vectorized tile and original tile. 
+ // column coordinate tx needs to be scaled by ElementCountPerVector to get the actual column coordinate in the original data type. + + uint2 coordInWeight = vectorizedTileCoordToWeightCoord(coordInTile, tileIndex); + + uint4 value; + uint indexInWeight; + for (uint i = 0; i < numLoadOrStorePerThread; i++) + { + if (coordInWeight.y >= TileInfo.HeightInElementsTileA) + break; + // Given the coordinate inside the tile and the tile index, we can get the index in the weight matrix. + // Note, we always treat the row of matrix A as uint4 aligned, so we always calculate 'indexInWeight' as if it's uint4 aligned, + // though it might not be. But IStorage.readUint4() will handle the padding. + // We only need to provide the actual boundary and the aligned boundary. + indexInWeight = flattenedIndex(coordInWeight); + + // Bounds check on global memory. + bool isOutOfRange = TransposeA ? + ((coordInWeight.y >= K) | (coordInWeight.x >= M)) : + ((coordInWeight.y >= M) | (coordInWeight.x >= K)); + + if (isStore == false) + { + if (isOutOfRange) + { + // If the coordinate is out of range of Matrix A, just pad with 0. + value = uint4(0, 0, 0, 0); + } + else + { + let offsetAddress = Storage.getOffset(weightAddress, indexInWeight); + // Though the weight matrix is not necessarily aligned with the uint4 vector size, readUint4() will handle the padding + // if the read crosses the boundary of the uint4 vector size. + // MatA can only be half type, so we can always read it as half type no matter what the data type of the weight matrix is. + value = weightStorage.readUint4(weightAddress, offsetAddress); + } + + sharedMemory[indexInVectorizedTile] = value; + } + else + { + // We don't need padding when storing back to global memory, because the padding is only needed when constructing the cooperative matrix. + // So if the coordinate is out of range, we don't need to store anything. + if (!isOutOfRange) + { + value = sharedMemory[indexInVectorizedTile]; + let offsetAddress = Storage.getOffset(weightAddress, indexInWeight); + + weightStorage.writeUint4Atomic(weightAddress, offsetAddress, value); + } + } + + indexInVectorizedTile += OffsetPerThreadLoadMatA; + + // Given the index in the tile, get the 2-D coordinate in the matrix A. + // Note that, when TransposeA is true, this is the 2-D coordinate in the matrix A^T. We don't change the semantics of the 2-D coordinate + // in the matrix, but we change the way to calculate the 1-D index in the matrix. But we can still use the same function to get the 2-D coordinate. 
+ coordInWeight = vectorizedTileCoordToWeightCoord( + vectorizedTileIndexTo2DCoord(indexInVectorizedTile), + tileIndex); + } + } + + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + VISIBILITY_LEVEL static void loadShA( + SPtr sharedMemory, + int tileIndex, + Storage weightStorage, + Storage.Address weightAddress) + where U : __BuiltinFloatingPointType + where U.Differential == U + where Storage : IStorage + { + matALoadStore(sharedMemory, tileIndex, weightStorage, weightAddress); + } + + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + VISIBILITY_LEVEL static void storeShA( + SPtr sharedMemory, + int tileIndex, + Storage weightStorage, + Storage.Address weightAddress) + where U : __BuiltinFloatingPointType + where U.Differential == U + where Storage : IStorage + { + matALoadStore(sharedMemory, tileIndex, weightStorage, weightAddress); + } + + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + internal static void vectorLoadStore( + SPtr sharedMemoryB, + int tileIndex, + int subgroupIndex, + inout InputArray inoutVector) + where U : __BuiltinFloatingPointType + where InputArray : IArrayAccessor + { + const uint laneId = WaveGetLaneIndex(); + + // select the correct shared memory layout configurations. + // In store case, we actually store the MatC, whose data type could not be `half`. + const uint elementPerVector = Op != AccessOp.READ ? CMShape.ElementCountPerVectorMatC : CMShape.ElementCountPerVector; + const uint heightInVectorTile = Op != AccessOp.READ ? TileInfo.HeightInVectorTileC : TileInfo.HeightInVectorTileB; + const uint sharedMemSizeInVector = Op != AccessOp.READ ? ShMemInfo.SharedMemSizeInVectorMatC : ShMemInfo.SharedMemSizeInVectorMatB; + + // In store case, we actually store the MatC. So the shared memory size could be different from MatB. + SPtr sharedMemoryBSubgroup = sharedMemoryB + subgroupIndex * sharedMemSizeInVector; + + // start index in vectorized tile for current thread: x * column_stride + const uint indexInVectorizedTile = laneId * heightInVectorTile; + const uint yOffset = (tileIndex * heightInVectorTile) * elementPerVector; + + [ForceUnroll] + for (uint yInTile = 0; yInTile < heightInVectorTile; yInTile++) + { + const int startIndex = yInTile * elementPerVector + yOffset; + + // Bounds check on Y direction. If y coordinate out of the input vector length, just padding with 0. + // No need to check the X direction, because the X direction is bound by the thread count, so any active + // thread will definitely have its thread-local vector available. + + bool isOutOfRange; + + if (Op == AccessOp.READ) + { + // For vector load, we are performing MMA, so the input vector has length of K + // in normal case, and M in transpose case. + isOutOfRange = TransposeA ? (startIndex >= M) : (startIndex >= K); + if (isOutOfRange) + { + sharedMemoryBSubgroup[indexInVectorizedTile + yInTile] = uint4(0, 0, 0, 0); + continue; + } + + // It's fine that we only use the aligned version of readUint4() here, because the inputVector is the internal data which we can always + // create it as an aligned array. + uint4 value; + + // Note we can only use half type for matrix A and B, the value must be packed by 8 half type elements. + accessUint4Aligned(inoutVector, startIndex, value); + sharedMemoryBSubgroup[indexInVectorizedTile + yInTile] = value; + } + else + { + // for vector store, the length of output vector has length of M in normal case + // and K in transpose case. + isOutOfRange = TransposeA ? 
(startIndex >= K) : (startIndex >= M); + + // No padding needs for writing back to local vector. + if (isOutOfRange) + return; + + // in store case, we always store MatC, so the data type is T. + uint4 value = sharedMemoryBSubgroup[indexInVectorizedTile + yInTile]; + accessUint4Aligned(inoutVector, startIndex, value); + } + } +# if 0 + // TODO: In most case (on AMD and NVIDIA GPU), N (SubgroupSize) >= AlignedN (aligned to tile width of matrix B), as + // subgroup size is usually larger than the tile width of matrix B. + // But in some corner cases (e.g. intel GPU), where the hardware only support very small subgroup size, N could be less than AlignedN. + // Not sure if we want to support this case! + if (N < AlignedN) + { + // If the active thread count is less than the aligned N, we need to pad the remaining columns with 0. + const uint numPaddingColumns = AlignedN - N; + + // The number of vectors to load == numPaddingColumns * HeightInVectorTileB. + // The number of vectors to load per thread == (numPaddingColumns * HeightInVectorTileB / N). + const uint numLoadsEachThreadMatB = (numPaddingColumns * HeightInVectorTileB + N - 1) / N; + const uint offsetPerThreadLoadMatB = N; + + // Because we already load N columns, the starting index is N * HeightInVectorTileB. + SPtr paddingShmPtr = sharedMemoryBSubgroup + N * HeightInVectorTileB; + + uint index = laneId; + for (uint i = 0; i < numLoadsEachThreadMatB; i++) + { + const uint xIndex = index / HeightInVectorTileB; + if (xIndex >= numPaddingColumns) + break; + paddingShmPtr[index] = uint4(0, 0, 0, 0); + index += offsetPerThreadLoadMatB; + } + } +# endif + } + + // Load the input vector into shared memory. Each input vector is a column of the matrix B. + // Since the input vector is loaded from thread local memory, one thread can only load one + // column of the matrix B. + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + VISIBILITY_LEVEL static void loadVectorToShB( + SPtr sharedMemoryB, + int tileIndex, + int subgroupIndex, + InputArray inputVector) + where U : __BuiltinFloatingPointType + where InputArray : IArrayAccessor + { + vectorLoadStore(sharedMemoryB, tileIndex, subgroupIndex, inputVector); + } + + // Load the input vector into shared memory. Each input vector is a column of the matrix B. + // Since the input vector is loaded from thread local memory, one thread can only load one + // column of the matrix B. + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + VISIBILITY_LEVEL static void storeVectorFromShB( + SPtr sharedMemoryB, + int tileIndex, + int subgroupIndex, + inout InputArray inoutVector) + where U : __BuiltinFloatingPointType + where InputArray : IArrayAccessor + { + vectorLoadStore(sharedMemoryB, tileIndex, subgroupIndex, inoutVector); + } + + // Read one complete vector into the shared memory. Different from the loadVectorToShB, + // this function is used to load the input vector into the shared memory for the transpose case, where each + // tile will not contain the all the vectors in a warp. Instead, only partial warp will load the complete vector + // into the shared memory. Compared to the LoadSharedMemoryFromLocalVector, where it loads partial vectors by every thread + // in a warp. 
+ [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + VISIBILITY_LEVEL static void loadVectorForOuterProduct( + SPtr sharedMemory, + int tileIndex, + InputArray inputVector) + where U : __BuiltinFloatingPointType + where U.Differential == U + where InputArray : IArrayAccessor + { + const uint laneId = WaveGetLaneIndex(); + + // Due to the limited size of shared memory, we cannot load all the vectors in a warp into shared memory. Instead, + // we can only load them in smaller batches (the batch size is the WMMA tile column or row). + if ((laneId / BatchSize) != tileIndex) + return; + + uint index = (laneId % BatchSize) * TileStrideInVector; + + for (uint i = 0; i < TileStrideInVector; i++) + { + uint4 value; + if (i >= ArraySizeInVector) + { + value = uint4(0, 0, 0, 0); + } + else + { + accessUint4Aligned(inputVector, i * CMShape.ElementCountPerVector, value); + } + sharedMemory[index + i] = value; + } + } + + [require(cooperative_matrix, subgroup_basic)] + internal static void sumReduceTiles(inout MatC matC[ROW*COLUMN], in int subgroupId, SPtr sharedMemory) + { + const int subgroupCount = getWaveCount(); + + // perform pairwise sum reduces: each iteration performs a 2-way reduce operation, + // for a total of log2(n) steps, where n is the number of subgroups. + for (int k = 1; k < subgroupCount; k <<= 1) + { + // (subgroupId / (k << 1)) calculates which cooperative matrix tile this warp will write to. + // The whole shared memory is divided into subgroupCount/2 parts. So each pair of neighboring subgroups + // will share one region; the left subgroup loads and the right subgroup stores. + uint subgroupOffset = (subgroupId / (k<<1)) * (CMShape.CoopMatCSizeInVector * COLUMN); + let ptrPerWarp = sharedMemory + subgroupOffset; + for (int i = 0; i < ROW; i++) + { + // Store one row of cooperative matrices to the shared memory. + for (int j = 0; j < COLUMN; j++) + { + // This is the right node in the 2-way merge operation. + if (subgroupId % (k<<1) == k) + { + // We choose to process one row of cooperative matrices at a time; storing the cooperative matrix in column major is more efficient, + // because in column major each column of the cooperative matrix is contiguous in the shared memory, so there will be no bank conflict. + matC[i * COLUMN + j].Store(ptrPerWarp + j * CMShape.CoopMatCSizeInVector, + CMShape.ROW_C / CMShape.ElementCountPerVectorMatC); + } + } + + // wait for all subgroups to finish storing the cooperative matrix to the shared memory. + GroupMemoryBarrierWithGroupSync(); + + // This is the left node in the 2-way merge operation. + if (subgroupId % (k<<1) == 0) + { + // If the left node is the last subgroup, it means that there is no right node for it, so we don't need to do any addition for this subgroup. + if (subgroupId != (subgroupCount - 1)) + { + for (int j = 0; j < COLUMN; j++) + { + MatC rightMatC = MatC.Load(ptrPerWarp + j * CMShape.CoopMatCSizeInVector, + CMShape.ROW_C / CMShape.ElementCountPerVectorMatC); + + matC[i * COLUMN + j] = matC[i * COLUMN + j] + rightMatC; + } + } + } + // wait for all subgroups to finish the addition. 
+ GroupMemoryBarrierWithGroupSync(); + } + } + } + + // Sum reduce of the dOut + [ForceInline] + [require(cuda_spirv, subgroup_partitioned)] + [require(cuda_spirv, sm_6_6)] + VISIBILITY_LEVEL static void sumReduceRows( + SPtr sharedMemory, + in InArrayType dOut, + in int subgroupId, + Storage biasStorage, + Storage.Address biasAddress) + where U : __BuiltinFloatingPointType + where U.Differential == U + where Storage : IStorage + where InArrayType : IArrayAccessor + { + const uint laneId = WaveGetLaneIndex(); + const int subgroupCount = getWaveCount(); + + // First, perform warp reduce, dOut's actual size is M + InArrayType localResult; + for (int i = 0; i < M; i++) + { + localResult[i] = WaveActiveSum(dOut[i]); + } + + // TODO: + // Naive way is that we still have to store to the shared memory by tile size, because + // we are reuse the shared memory B for the shared memory saving sake. However, more smart + // way is that we can compute the usage here during compile time, and compare with the total + // size of the shared memory pool. Since this sum reduce compute is not overlapped with other + // computation, so we can actually freely use all the shared memory in the pool. + static const int elementPerVector = sizeof(uint4) / sizeof(U); + static const int NumTiles = (Uint4AlignedM + TileInfo.HeightInElementTileB - 1) / TileInfo.HeightInElementTileB; + + // This loop is used to move the local result of each warp into the first warp + [ForceUnroll] + for (uint i = 0; i < NumTiles; i++) + { + const int startIndex = i * TileInfo.HeightInElementTileB; + + // lane 0 of each warp will store the local result to the shared memory. + if (laneId == 0) + { + [ForceUnroll] + for (uint j = 0; j < TileInfo.HeightInVectorTileB; j++) + { + uint4 value; + const uint offset = startIndex + j * elementPerVector; + if (offset >= M) + break; + accessUint4Aligned(localResult, offset, value); + sharedMemory[subgroupId * TileInfo.HeightInVectorTileB + j] = value; + } + } + + GroupMemoryBarrierWithGroupSync(); + + // Now all the results are stored in the shared memory, we can perform warp reduce to get the final result. + if (subgroupId == 0) + { + if (laneId < subgroupCount) + { + [ForceUnroll] + for (uint j = 0; j < TileInfo.HeightInVectorTileB; j++) + { + uint4 value = sharedMemory[laneId * TileInfo.HeightInVectorTileB + j]; + int offset = startIndex + j * elementPerVector; + if (offset >= M) + break; + accessUint4Aligned(localResult, offset, value); + } + } + } + } + + // Now all the partial result are stored in the registers in the first warp, so we can perform the one more + // warp reduce to get the final result. + if (subgroupId == 0) + { + // Only perform the reduce for the lane ids smaller than the 'subgroupCount', because there + // are only 'subgroupCount' vectors in the shared memory. + WaveMask mask = WaveMaskBallot(0xFFFFFFFF, laneId < subgroupCount); + [ForceUnroll] + for (uint j = 0; j < M; j++) + { + localResult[j] = WaveMaskSum(mask, localResult[j]); + } + + // spread the result writing to all threads in the first subgroup. 
+ // Atomic add to the bias gradient storage buffer + static const int numIter = (M + SubgroupSize - 1) / SubgroupSize; + [ForceUnroll] + for (uint i = 0; i < numIter; i++) + { + const int index = i * SubgroupSize + laneId; + if (index >= M) + break; + + let offset = Storage.getOffset(biasAddress, index); + biasStorage.atomicAdd(offset, localResult[index]); + } + } + } + + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + [require(spvAtomicFloat16AddEXT)] + VISIBILITY_LEVEL static void outerProductAccumulate( + SPtr sharedMemoryA, + SPtr sharedMemoryB, + InArrayTypeA inputVectorA, + InArrayTypeB inputVectorB, + Storage weightStorage, + Storage.Address weightAddress) + where U : __BuiltinFloatingPointType + where U.Differential == U + where Storage : IStorage + where TypeA : __BuiltinFloatingPointType + where TypeA.Differential == TypeA + where TypeB : __BuiltinFloatingPointType + where TypeB.Differential == TypeB + where InArrayTypeA : IArrayAccessor + where InArrayTypeB : IArrayAccessor + { + const uint subgroupIndex = getWaveId(); + + // Batch size is either CoopMatA row count or CoopMatB column count, they are the same. + static const int BatchSize = CMShape.COLUMN_A; + static const int CoopMatCRows = (M + CMShape.ROW_A - 1) / CMShape.ROW_A; + static const int CoopMatCColumns = (K + CMShape.COLUMN_A - 1) / CMShape.COLUMN_A; + static const int StrideInVectorTileB = CoopMatCColumns * CMShape.COLUMN_B / CMShape.ElementCountPerVector; + static const int StrideInVectorTileA = CoopMatCRows * CMShape.ROW_A / CMShape.ElementCountPerVector; + static const int StrideInVectorTileC = CoopMatCColumns * CMShape.COLUMN_C / CMShape.ElementCountPerVectorMatC; + + MatC matC[CoopMatCRows * CoopMatCColumns]; + for (int i = 0; i < CoopMatCRows * CoopMatCColumns; i++) + { + matC[i].fill(T(0.0f)); + } + + SPtr warpLocalPtrA = sharedMemoryA + subgroupIndex * StrideInVectorTileA * BatchSize; + SPtr warpLocalPtrB = sharedMemoryB + subgroupIndex * StrideInVectorTileB * BatchSize; + + // Perform outer product for each warp. + // The shared dimension is the Subgroup size N. + for (uint k = 0; k < N; k += CMShape.COLUMN_A) + { + uint tileIndex = k / CMShape.COLUMN_A; + loadVectorForOuterProduct(warpLocalPtrA, tileIndex, inputVectorA); + loadVectorForOuterProduct(warpLocalPtrB, tileIndex, inputVectorB); + GroupMemoryBarrierWithWaveSync(); + + MatA matA[CoopMatCRows]; + for (uint i = 0; i < CoopMatCRows; i++) + { + matA[i] = MatA.Load( warpLocalPtrA + i * CMShape.ROW_A / CMShape.ElementCountPerVector, StrideInVectorTileA); + } + + for (uint j = 0; j < CoopMatCColumns; j++) + { + // For matrix B, stride will be the width of the tile. This is different from the forward pass in A*B where the stride is height of tile B. + // Since width of the tile B in outer product is K (input vector length), so the stride will be alignedK / ElementCountPerVector. + let matB = MatB.Load(warpLocalPtrB + j * CMShape.COLUMN_B / CMShape.ElementCountPerVector, StrideInVectorTileB); + for (uint i = 0; i < CoopMatCRows; i++) + { + int index = i * CoopMatCColumns + j; + matC[index] = linalg.coopMatMulAdd(matA[i], matB, matC[index]); + } + } + } + + // sum reduce will accumulate the result cross all the warps into the cooperative matrix in the first subgroup. + sumReduceTiles(matC, subgroupIndex, sharedMemoryB); + + // Only the first subgroup need to store the result to the global memory. + + if (subgroupIndex == 0) + { + // Store the result shared memory. We store one row of the cooperative matrix at a time. 
+ for (uint i = 0; i < CoopMatCRows; i++) + { + for (uint j = 0; j < CoopMatCColumns; j++) + { + // Save the result to shared memory in row major, because this is cache friendly for global memory access. + // As each row of the tile will be corresponding to the contiguous global memory. + matC[i * CoopMatCColumns + j].Store(sharedMemoryB + j * CMShape.COLUMN_C / CMShape.ElementCountPerVectorMatC, (uint)StrideInVectorTileC); + + } + + // We will leverage the fact that the result the outerproduct is the same shape as matrix A. But, the tile organization is different. + // In forward pass, the tile in the shared memory is in shape M x T_ROW_A, which is one column of cooperative matrices. + // while the result here is in shape T_ROW_A x K, which is one row of cooperative matrices. + // So we can use transposed version of LoadShA/StoreShA to store the result. Because in transpose version, the major-ness is along + // K dimension. Note that column major in transpose case is just the row major in non-transpose case. + + typealias MMA = MMAHelper; + MMA.storeShA(sharedMemoryB, i, weightStorage, weightAddress); + + // Wait until all threads in the warp finish reading the shared memory and storing to the global memory. + GroupMemoryBarrierWithWaveSync(); + } + } + + AllMemoryBarrierWithGroupSync(); + } + + // The workload is different from the large size matrix-matrix multiply, we actually perform the matrix-vector multiply + // for each thread, so the matrix-matrix multiply only needs to be performed for each warp (sub-group). + [ForceInline] + [require(cooperative_matrix, subgroup_basic)] + VISIBILITY_LEVEL static OutArrayType mma(InArrayType inputVector, + SPtr sharedMemoryA, + SPtr sharedMemoryB, + SPtr sharedMemoryC, + Storage weightStorage, + Storage.Address weightAddress, + Optional biasStorage, + Optional biasAddress) + where U : __BuiltinFloatingPointType + where U.Differential == U + where Storage : IStorage + where V : __BuiltinFloatingPointType + where V.Differential == V + where InArrayType : IArrayAccessor + where OutArrayType : IArrayAccessor + { + SPtr ptrA[EnableDoubleBuffer ? 2 : 1]; + SPtr ptrB[EnableDoubleBuffer ? 2 : 1]; + SPtr ptrC = sharedMemoryC; + const uint subgroupIndex = getWaveId(); + + if (!EnableDoubleBuffer) + { + ptrA[0] = sharedMemoryA; + ptrB[0] = sharedMemoryB; + } + else + { + ptrA[0] = sharedMemoryA; + ptrA[1] = sharedMemoryA + ShMemInfo.SharedMemSizeInVectorMatA; + ptrB[0] = sharedMemoryB; + ptrB[1] = sharedMemoryB + ShMemInfo.SharedMemSizeInVectorMatB; + // PreLoad for the first iteration + loadShA(ptrA[0], 0, weightStorage, weightAddress); + loadVectorToShB(ptrB[0], 0, subgroupIndex, inputVector); + } + + // fetch first tile of Matrix A and Matrix B into shared memory + // Iterate over the K dimension, T.COLUMN is column of Matrix A. + uint bufferIndex = 0; + MatC matC[TileInfo.NCoopMat]; + [ForceUnroll] + for (int i = 0; i < TileInfo.NCoopMat; i++) + { + matC[i].fill(T(0.0f)); + } + + OutArrayType outputVector; + if (Bias) + { + [ForceUnroll] + for (int i = 0; i < M; i++) + { + Storage.Address offset = Storage.getOffset(biasAddress.value, i); + U bias = biasStorage.value.read(offset); + outputVector[i] = __realCast(bias); + } + } + + for (int k = 0; k < TileInfo.SharedDimensionSize; k += CMShape.COLUMN_A) + { + uint tileIndex = k / CMShape.COLUMN_A; + + if (EnableDoubleBuffer) + { + GroupMemoryBarrierWithGroupSync(); + + // swap to next buffer to load new data. 
+ bufferIndex = bufferIndex ^ 1; + // Load the another buffer and do the math so we can hide the latency of the load. + loadShA(ptrA[bufferIndex], tileIndex, weightStorage, weightAddress); + loadVectorToShB(ptrB[bufferIndex], tileIndex, subgroupIndex, inputVector); + + // swap back to the previous buffer to perform the math. + bufferIndex = bufferIndex ^ 1; + } + else + { + loadShA(ptrA[0], tileIndex, weightStorage, weightAddress); + loadVectorToShB(ptrB[0], tileIndex, subgroupIndex, inputVector); + + // The reason that we have to sync the whole workgroup is that for matrix A, the tile is shared by + // whole workgroup. + + // TODO: + // The alternative solution to duplicate the tile A to each subgroup, so we will only need warp sync here. + // But that will waste lots of shared memory and increase memory transactions. We need benchmark to see + // which solution is better. + GroupMemoryBarrierWithGroupSync(); + } + + // Math loop: This operator is executed by each warp (sub-group). shA is shared by all threads in the workgroup, while + // shB is shared only by each subgroup, and each subgroup will have its own offset on shB. + // For matB, each warp could only have 1 or 2 Tiles to load according whether this warp is less (1 tile) or more than half warp (2 tiles). + MatA matA[TileInfo.NCoopMatRow]; + for (uint i = 0; i < TileInfo.NCoopMatRow; i ++) + { + if (TransposeA) + { + // Really be careful the stride of tile A in the transpose case. The tile is column major, and the cooperative matrix is stacked in columns of the tile. + // So the stride is actually the height the tile. The offset between two cooperative matrices is just height of the cooperative matrix. + matA[i] = MatA.Load(ptrA[bufferIndex] + i * CMShape.ROW_A/CMShape.ElementCountPerVector, TileInfo.HeightInVectorTileA); + } + else + { + matA[i] = MatA.Load(ptrA[bufferIndex] + i * CMShape.CoopMatASizeInVector, TileInfo.WidthInVectorTileA); + } + } + + for (uint j = 0; j < TileInfo.NCoopMatColumn; j ++) + { + // NTilesColumn * HeightInVectorTileB is the size of the shared memory for matrix B in one warp measured in vector uint4. + SPtr ptrPerWarp = ptrB[bufferIndex] + subgroupIndex * ShMemInfo.SharedMemSizeInVectorMatB; + let matB = MatB.Load(ptrPerWarp + j * CMShape.CoopMatBSizeInVector, TileInfo.HeightInVectorTileB); + + for (uint i = 0; i < TileInfo.NCoopMatRow; i ++) + { + int index = i * TileInfo.NCoopMatColumn + j; + matC[index] = linalg.coopMatMulAdd(matA[i], matB, matC[index]); + } + } + } + + // Write back the result to shared memory B. We store the result in column major because each column is + // actually the output vector of the matrix-vector multiply, which is also a thread-local vector. + + // Get start address of the shared memory for current warp. + SPtr ptrPerWarp = ptrC + subgroupIndex * ShMemInfo.SharedMemSizeInVectorMatC; + + [ForceUnroll] + for (int i = 0; i < TileInfo.NCoopMatRow; i ++) + { + [ForceUnroll] + for (int j = 0; j < TileInfo.NCoopMatColumn; j ++) + { + int index = i * TileInfo.NCoopMatColumn + j; + matC[index].Store(ptrPerWarp + j * CMShape.CoopMatCSizeInVector, + TileInfo.HeightInVectorTileC); + } + + if (Bias) + { + storeVectorFromShB(ptrC, i, subgroupIndex, outputVector); + } + else + { + storeVectorFromShB(ptrC, i, subgroupIndex, outputVector); + } + + // wati until all threads in the warp finish reading shared memory and storing to local vector. 
+ GroupMemoryBarrierWithWaveSync(); + } + + return outputVector; + } +} + +// Cooperative matrix is only supported by CUDA and SPIR-V +// WaveTangledVector is a vector type that emulates the Cooperative Vector type by using Cooperative Matrix feature which is +// supported by CUDA and SPIR-V. +[require(cooperative_matrix, subgroup_basic)] +public struct WaveTangledVector : IVector + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + public typealias Differential = WaveTangledVector; + + public static const int Size = N; + public no_diff ShMemPool shMemPool; + + private typealias DTypeMatC = half; + + // TODO: This is just an easiest solution to make the alignment work. But the disadvantage is wasting of registers. + // A better solution is that we can implement a "special" alignment rule for the vectorized reader. + // Normally, the alignment rule is that one uint4 can pack 8 half type elements, so we want our internal data + // is always 8 elements aligned. But we can have a "byte-aligned" rule such that we just need our data is 16 bytes aligned. + // So for example, float[4] is also considered as aligned. We can just pack 4 elements into a uint4, the remaining 4 elements + // just fill with 0. + // Since we always use half type for A and B, we want to follow the alignment requirement of half type. + internal static const uint ElementCountPerVector = sizeof(uint4) / sizeof(DTypeMatC); + internal static const int Uint4AlignedInputSize = ((N + (ElementCountPerVector - 1)) / ElementCountPerVector) * ElementCountPerVector; + + // [DerivativeMember(Differential.data)] + internal T[Uint4AlignedInputSize] data; + public int getCount() { return N; } + + public __init() { data = {}; } + + public __init(T value) + { + [ForceUnroll] + for (int i = 0; i < N; i++) + { + this.data[i] = value; + } + + cleanPaddingData(); + } + + public __init(T[Size] inData) + { + [ForceUnroll] + for (int i = 0; i < N; i++) + { + this.data[i] = inData[i]; + } + + cleanPaddingData(); + } + + internal __init(T[Uint4AlignedInputSize] inData) + { + this.data = inData; + } + + public __init(This other) { this.data = other.data; } + + public __init>(InputArray inData) + { + [ForceUnroll] + for (int i = 0; i < N; i++) + this.data[i] = inData[i]; + + cleanPaddingData(); + } + + [ForceInline] + [mutating] + internal void cleanPaddingData() + { + const int elementCountPerUint4 = sizeof(uint4) / sizeof(DTypeMatC); + const int alignedSize = ((N + (elementCountPerUint4 - 1)) / elementCountPerUint4) * elementCountPerUint4; + + [ForceUnroll] + for (int i = N; i < alignedSize; i++) + { + this.data[i] = T(0.0f); + } + } + + public __subscript(int index) -> T + { + [ForceInline] + get { return this.data[index]; } + + [ForceInline] + set { this.data[index] = newValue; } + } + + private OutputVector linearTransformOnTarget( + Storage weight, + no_diff Storage.Address weightAddress, + no_diff Optional bias, + no_diff Optional biasAddress) + where Storage : IStorage + where Storage.Differential : IStorage + where Storage.Address == Storage.Differential.Address + where OutputVector : IVector + { + typealias MMA = MMAHelper; + SPtr shA = shMemPool.getPointer(); + SPtr shB = shA + MMA.ShMemInfo.SharedMemSizeInVectorMatA; + SPtr shC = shB; + const int AlignedOutSize = MMA.TileInfo.Uint4AlignedM; + let outputArray = MMA.mma( data, shA, shB, shC, weight, weightAddress, bias, biasAddress); + return OutputVector(outputArray); + } + + private static void linearTransformBwdOnTarget( + inout DifferentialPair dthis, + 
DifferentialPtrPair dWeightStorage, + no_diff Storage.Address dWeightAddress, + Optional> biasStorage, + no_diff Optional biasAddress, + OutputVector.Differential doutput) + where Storage : IStorage + where Storage.Differential : IStorage + where Storage.Address == Storage.Differential.Address + where OutputVector : IVector + where OutputVector.Differential : IVector + { + typealias MMA = MMAHelper; + + SPtr shA = dthis.p.shMemPool.getPointer(); + SPtr shB = shA + MMA.ShMemInfo.SharedMemSizeInVectorMatA; + SPtr shC = shB; + + // In backward, output size is K dimension. + const int AlignedOutSize = MMA.TileInfo.Uint4AlignedK; + const int AlignedInputSize = MMA.TileInfo.Uint4AlignedM; + + // This.Differential is the derivative of the input vector, which is the output + // of the mma operation. + // dIn = W^T * dOut; +#if 1 + let outArray = MMA.mma( doutput, shA, shB, shC, dWeightStorage.p, dWeightAddress, none, none); + This.Differential dInput = This.Differential(outArray); + dthis = DifferentialPair(dthis.p, dInput); + + // dW = dOut * input^T + MMA.outerProductAccumulate( shA, shB, doutput, dthis.p.data, dWeightStorage.d, dWeightAddress); +#else + let dInput = MMA.mma( doutput, shA, shB, shC, dWeightStorage.p, dWeightAddress); + dthis = DifferentialPair(dthis.p, dInput); + + MMA.outerProductAccumulate< T.Differential, Storage.Differential, + T.Differential, OutputVector.Differential, + T, This + >( shA, shB, doutput, dthis.p, dWeightStorage.d, dWeightAddress); +#endif + + if (Bias) + { + // VISIBILITY_LEVEL static void sumReduceRows( + // SPtr sharedMemory, + // in InArrayType dOut, + // in int subgroupId, + // Storage biasStorage, + // Storage.Address biasAddress) + const int subgroupIndex = getWaveId(); + MMA.sumReduceRows(shB, doutput, subgroupIndex, biasStorage.value.d, biasAddress.value); + } + } + + // Linear transformation without bias + [Differentiable] + [BackwardDerivative(linearTransformBwd)] + public OutputVector linearTransform( + Storage weightStorage, + no_diff Storage.Address weightAddress) + where Storage : IStorage + where Storage.Differential : IStorage + where Storage.Address == Storage.Differential.Address + where Layout : IStorageLayout + where OutputVector : IVector + { + __target_switch + { + case cuda: + return no_diff linearTransformOnTarget(weightStorage, weightAddress, none, none); + case spirv: + return no_diff linearTransformOnTarget(weightStorage, weightAddress, none, none); + } + } + + // Backward of linear transformation without bias + static void linearTransformBwd( + inout DifferentialPair dthis, + DifferentialPtrPair dWeightStorage, + no_diff Storage.Address dWeightAddress, + OutputVector.Differential doutput) + where Storage : IStorage + where Storage.Differential : IStorage + where Storage.Address == Storage.Differential.Address + where Layout : IStorageLayout + where OutputVector : IVector + where OutputVector.Differential : IVector + { + Optional> biasStorage = {}; + __target_switch + { + case cuda: + linearTransformBwdOnTarget(dthis, dWeightStorage, dWeightAddress, biasStorage, none, doutput); + case spirv: + linearTransformBwdOnTarget(dthis, dWeightStorage, dWeightAddress, biasStorage, none, doutput); + } + } + + [Differentiable] + [BackwardDerivative(linearTransformBwd)] + public OutputVector linearTransform( + Storage weightStorage, + Storage biasStorage, + no_diff Storage.Address weightAddress, + no_diff Storage.Address biasAddress) + where Storage : IStorage + where Storage.Differential : IStorage + where Storage.Address == 
Storage.Differential.Address + where Layout : IStorageLayout + where OutputVector : IVector + { + __target_switch + { + case cuda: + return no_diff linearTransformOnTarget(weightStorage, weightAddress, biasStorage, biasAddress); + case spirv: + return no_diff linearTransformOnTarget(weightStorage, weightAddress, biasStorage, biasAddress); + } + } + + // Backward of linear transformation with bias + static void linearTransformBwd( + inout DifferentialPair dthis, + DifferentialPtrPair dWeightStorage, + DifferentialPtrPair dBiasStorage, + no_diff Storage.Address dWeightAddress, + no_diff Storage.Address dBiasAddress, + OutputVector.Differential doutput) + where Storage : IStorage + where Storage.Differential : IStorage + where Storage.Address == Storage.Differential.Address + where Layout : IStorageLayout + where OutputVector : IVector + where OutputVector.Differential : IVector + { + __target_switch + { + case cuda: + linearTransformBwdOnTarget( dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput); + case spirv: + linearTransformBwdOnTarget(dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput); + } + } + + [Differentiable] + public OutputVector linearTransform( + Address weightAddress) + where Address : IPointerLikeAddress + where Address.Differential : IPointerLikeAddress + where Layout : IStorageLayout + where OutputVector : IVector + { + OutputVector output = OutputVector(); + return output; + } + + [Differentiable] + public OutputVector linearTransform( + Address weightAddress, Address biasAddress) + where Address : IPointerLikeAddress + where Address.Differential : IPointerLikeAddress + where Layout : IStorageLayout + where OutputVector : IVector + { + OutputVector output = OutputVector(); + return output; + } +} diff --git a/source/standard-modules/neural/bindless-storage.slang b/source/standard-modules/neural/bindless-storage.slang index 14d2ea0975..bda1038a21 100644 --- a/source/standard-modules/neural/bindless-storage.slang +++ b/source/standard-modules/neural/bindless-storage.slang @@ -128,7 +128,22 @@ public struct BindlessBufferStorage : IStorage public typealias Differential = BindlessBufferStorage; // Following method will not be needed for bindless storage - public T read(Address address) {return address[0];} + public T read(Address address) { return address[0]; } + + internal uint4 readUint4(Address baseAddress, Address address) + where DstType : __BuiltinFloatingPointType + where DstType.Differential == DstType + { + static_assert(false, "Not implemented"); + return uint4(0); + } + + internal void writeUint4Atomic(Address baseAddress, Address address, uint4 value) + where SrcType : __BuiltinFloatingPointType + where SrcType.Differential == SrcType + { + static_assert(false, "Not implemented"); + } public void atomicAdd(Address address, T value) {address.atomicAdd(0, value);} public void write(Address address, T value) {address[0] = value;} public static Address getOffset(Address base, int elements) { return base.getOffset(elements); } @@ -142,7 +157,21 @@ public struct PointerStorage : IStorage public typealias Differential = PointerStorage; // Following method will not be needed for pointer storage - public T read(Address address) {return address[0];} + public T read(Address address) { return address[0]; } + + internal uint4 readUint4(Address baseAddress, Address address) + where DstType : __BuiltinFloatingPointType + where DstType.Differential == DstType + { + static_assert(false, "Not implemented"); + return uint4(0); + } + 
internal void writeUint4Atomic(Address baseAddress, Address address, uint4 value) + where SrcType : __BuiltinFloatingPointType + where SrcType.Differential == SrcType + { + static_assert(false, "Not implemented"); + } public void atomicAdd(Address address, T value) {address.atomicAdd(0, value);} public void write(Address address, T value) {address[0] = value;} public static Address getOffset(Address base, int elements) {return base.getOffset(elements);} diff --git a/source/standard-modules/neural/buffer-storage.slang b/source/standard-modules/neural/buffer-storage.slang index f1d2a7de9f..09225953a3 100644 --- a/source/standard-modules/neural/buffer-storage.slang +++ b/source/standard-modules/neural/buffer-storage.slang @@ -10,6 +10,7 @@ Supports offset-based addressing and atomic operations for gradient accumulation - `T.Differential` must conform to `__BuiltinFloatingPointType` for automatic differentiation @category neural */ + public struct StructuredBufferStorage : IStorage where T : __BuiltinFloatingPointType where T.Differential == T @@ -17,7 +18,7 @@ public struct StructuredBufferStorage : IStorage /// Address type is a simple unsigned integer index. public typealias Address = uint; - /// The underlying buffer type. + /// The underlying buffer type.s public typealias BufferType = RWStructuredBuffer; /// Differential type for automatic differentiation. @@ -49,16 +50,37 @@ public struct StructuredBufferStorage : IStorage @param[in] address The address relative to base address. @return The value at the specified address. */ + + [ForceInline] public T read(Address address) { return m_buffer[address + m_baseAddress]; } + [ForceInline] + internal uint4 readUint4(Address baseAddress, Address address) + where DstType : __BuiltinFloatingPointType + where DstType.Differential == DstType + { + uint4 value; + accessUint4(m_buffer, baseAddress, address, value); + return value; + } + + [ForceInline] + internal void writeUint4Atomic(Address baseAddress, Address address, uint4 value) + where SrcType : __BuiltinFloatingPointType + where SrcType.Differential == SrcType + { + accessUint4(m_buffer, baseAddress, address, value); + } + /** Atomically adds a value (for gradient accumulation). @param[in] address The address relative to base address. @param[in] value The value to add. */ + [ForceInline] [require(cuda_glsl_hlsl_metal_spirv, sm_6_6)] public void atomicAdd(Address address, T value) { @@ -99,6 +121,7 @@ public struct StructuredBufferStorage : IStorage @param[in] address The address relative to base address. @param[in] value The value to write. */ + [ForceInline] public void write(Address address, T value) { m_buffer[address + m_baseAddress] = value; diff --git a/source/standard-modules/neural/inline-vector.slang b/source/standard-modules/neural/inline-vector.slang index caed6873f9..6a8a879aca 100644 --- a/source/standard-modules/neural/inline-vector.slang +++ b/source/standard-modules/neural/inline-vector.slang @@ -12,7 +12,7 @@ for gradient computation in neural networks. - `T.Differential` must conform to `__BuiltinFloatingPointType` for automatic differentiation @category neural */ -public struct InlineVector : IVector +public struct InlineVector : IVector where T : __BuiltinFloatingPointType where T.Differential == T { @@ -22,6 +22,7 @@ public struct InlineVector : IVector /// The compile-time size of the vector. public static const int Size = N; + public int getCount() {return N;} /** Internal storage for vector elements. 
@remarks Marked as derivative member to enable automatic differentiation. @@ -46,7 +47,7 @@ public struct InlineVector : IVector Array constructor - initializes from an array. @param[in] data Array of N elements to initialize the vector. */ - public __init(T[N] data) { this.data = data; } + public __init(T[Size] data) { this.data = data; } /** Copy constructor. @@ -54,6 +55,14 @@ public struct InlineVector : IVector */ public __init(This other) { this.data = other.data; } + public __init>(InputArray data) + { + static_assert(data.getCount() >= N, "The size of the input array must match the vector size"); + [ForceUnroll] + for (int i = 0; i < N; i++) + this.data[i] = data[i]; + } + /** Element access operator. @param[in] index The element index (0-based). @@ -67,18 +76,20 @@ public struct InlineVector : IVector // Linear transformation without bias [BackwardDerivative(linearTransformBwd)] - public OutputVector linearTransform( + public OutputVector linearTransform( Storage weightStorage, no_diff Storage.Address weightAddress) where Storage : IStorage where Storage.Differential : IStorage where Storage.Address == Storage.Differential.Address - where OutputVector : IVector + where Layout : IStorageLayout + where OutputVector : IVector { OutputVector output = OutputVector(); + static const int outSize = OutputVector.Size; - [MaxIters(OutputSize)] - for (int row = 0; row < OutputSize; row++) + [MaxIters(outSize)] + for (int row = 0; row < outSize; row++) { // get the address of each row of the weight matrix let rowOffset = Storage.getOffset(weightAddress, row * N); @@ -95,7 +106,7 @@ public struct InlineVector : IVector // Linear transformation with bias [BackwardDerivative(linearTransformBwd)] - public OutputVector linearTransform( + public OutputVector linearTransform( Storage weightStorage, Storage biasStorage, no_diff Storage.Address weightAddress, @@ -103,14 +114,15 @@ public struct InlineVector : IVector where Storage : IStorage where Storage.Differential : IStorage where Storage.Address == Storage.Differential.Address - where OutputVector : IVector + where Layout : IStorageLayout + where OutputVector : IVector { // Reuse the unbias matmul method - OutputVector output = this.linearTransform(weightStorage, weightAddress); + OutputVector output = this.linearTransform(weightStorage, weightAddress); // apply the bias [ForceUnroll] - for (int i = 0; i < OutputSize; i++) + for (int i = 0; i < OutputVector.Size; i++) { let elementOffset = Storage.getOffset(biasAddress, i); output[i] = output[i] + biasStorage.read(elementOffset); @@ -120,7 +132,7 @@ public struct InlineVector : IVector } // Backward of linear transformation without bias - static void linearTransformBwd( + static void linearTransformBwd( inout DifferentialPair dthis, DifferentialPtrPair dWeightStorage, no_diff Storage.Address dWeightAddress, @@ -128,20 +140,23 @@ public struct InlineVector : IVector where Storage : IStorage where Storage.Differential : IStorage where Storage.Address == Storage.Differential.Address - where OutputVector : IVector + where Layout : IStorageLayout + where OutputVector : IVector { // Derivative of the input is transposed weight matrix times the output differential - var d = dthis.d; + This.Differential d = {}; - [MaxIters(OutputSize)] - for (int j = 0; j < OutputSize; j++) + [MaxIters(OutputVector.Size)] + for (int j = 0; j < OutputVector.Size; j++) { T.Differential dy = doutput[j]; [ForceUnroll] + // N is the column count of the weight matrix. 
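+            // Illustrative example (assuming the 2x4 row-major weight matrix W = [[1, 2, 3, 4], [5, 6, 7, 8]]
+            // and dOutput = [1, 1] used by the tests): dInput[i] = sum_j W[j][i] * dOutput[j], so
+            // dInput = [1+5, 2+6, 3+7, 4+8] = [6, 8, 10, 12]. The transposed element W[j][i] lives at
+            // offset i + j * N in the row-major storage, which is the indexing used in the loop below.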
for (int i = 0; i < N; i++) { - Storage.Address elementOffset = Storage.getOffset(dWeightAddress, i * OutputSize + j); + // On the transpose, we just perform linear combination of each row for the cache friendly. + Storage.Address elementOffset = Storage.getOffset(dWeightAddress, i + j * N); T.Differential prod = T.Differential.dmul(dWeightStorage.p.read(elementOffset), dy); d[i] = T.Differential.dadd(d[i], prod); } @@ -149,8 +164,8 @@ public struct InlineVector : IVector // Derivative of the weights is the outer product of the input and the output differential // dW = dOutput * dInput^T - [MaxIters(OutputSize)] - for (int row = 0; row < OutputSize; row++) + [MaxIters(OutputVector.Size)] + for (int row = 0; row < OutputVector.Size; row++) { let rowOffset = Storage.getOffset(dWeightAddress, row * N); T.Differential dy = doutput[row]; @@ -168,7 +183,7 @@ public struct InlineVector : IVector } // Backward of linear transformation with bias - static void linearTransformBwd( + static void linearTransformBwd( inout DifferentialPair dthis, DifferentialPtrPair dWeightStorage, DifferentialPtrPair dBiasStorage, @@ -178,14 +193,15 @@ public struct InlineVector : IVector where Storage : IStorage where Storage.Differential : IStorage where Storage.Address == Storage.Differential.Address - where OutputVector : IVector + where Layout : IStorageLayout + where OutputVector : IVector { // Reuse the unbias backward method - linearTransformBwd(dthis, dWeightStorage, dWeightAddress, doutput); + linearTransformBwd(dthis, dWeightStorage, dWeightAddress, doutput); // Derivative of the bias is the same as the output differential [ForceUnroll] - for (int i = 0; i < OutputSize; i++) + for (int i = 0; i < OutputVector.Size; i++) { let biasOffset = Storage.getOffset(dBiasAddress, i); dBiasStorage.d.atomicAdd(biasOffset, doutput[i]); @@ -194,17 +210,18 @@ public struct InlineVector : IVector // Linear transformation without bias (Bindless storage) [BackwardDerivative(linearTransformBwd)] - public OutputVector linearTransform( + public OutputVector linearTransform( Address weightAddress) where Address : IPointerLikeAddress where Address.Differential : IPointerLikeAddress - where OutputVector : IVector + where Layout : IStorageLayout + where OutputVector : IVector { var output = OutputVector(); // output = W * input - [MaxIters(OutputSize)] - for (int row = 0; row < OutputSize; row++) + [MaxIters(OutputVector.Size)] + for (int row = 0; row < OutputVector.Size; row++) { // get the address of each row of the weight matrix let rowAddr = weightAddress.getOffset(row * N); @@ -219,51 +236,53 @@ public struct InlineVector : IVector // Linear transformation with bias (Bindless storage) [BackwardDerivative(linearTransformBwd)] - public OutputVector linearTransform( + public OutputVector linearTransform( Address weightAddress, Address biasAddress) where Address : IPointerLikeAddress where Address.Differential : IPointerLikeAddress - where OutputVector : IVector + where Layout : IStorageLayout + where OutputVector : IVector { // Reuse the unbias matmul method - OutputVector output = this.linearTransform(weightAddress); + OutputVector output = this.linearTransform(weightAddress); [ForceUnroll] - for (int i = 0; i < OutputSize; i++) + for (int i = 0; i < OutputVector.Size; i++) output[i] = output[i] + biasAddress[i]; return output; } // Backward of linear transformation without bias (Bindless storage) - static public void linearTransformBwd( + static public void linearTransformBwd( inout DifferentialPair dthis, DifferentialPtrPair
dparameters, OutputVector.Differential doutput) - where Address : IPointerLikeAddress - where Address.Differential : IPointerLikeAddress - where OutputVector : IVector - where OutputVector.Differential : IVector + where Address : IPointerLikeAddress + where Address.Differential : IPointerLikeAddress + where Layout : IStorageLayout + where OutputVector : IVector + where OutputVector.Differential : IVector { // dInput = dW^T * dOutput - var d = dthis.d; - [MaxIters(OutputSize)] - for (int j = 0; j < OutputSize; j++) + This.Differential d = {}; + [MaxIters(OutputVector.Size)] + for (int j = 0; j < OutputVector.Size; j++) { let dy = doutput[j]; [ForceUnroll] for (int i = 0; i < N; i++) { - T.Differential prod = T.Differential.dmul(dparameters.p[i * OutputSize + j], dy); + T.Differential prod = T.Differential.dmul(dparameters.p[i + j * N], dy); d[i] = T.Differential.dadd(d[i], prod); } } // Derivative of the weights is the outer product of the input and the output differential - // dW = dOutput * dInput^T - [MaxIters(OutputSize)] - for (int row = 0; row < OutputSize; row++) + // dW = dOutput * Input^T + [MaxIters(OutputVector.Size)] + for (int row = 0; row < OutputVector.Size; row++) { let rowAddr = dparameters.d.getOffset(row * N); T.Differential dy = doutput[row]; @@ -280,22 +299,23 @@ public struct InlineVector : IVector } // Backward of linear transformation with bias (Bindless storage) - static public void linearTransformBwd( + static public void linearTransformBwd( inout DifferentialPair dthis, DifferentialPtrPair
dWeightAddress, DifferentialPtrPair
dBiasAddress, OutputVector.Differential doutput) - where Address : IPointerLikeAddress - where Address.Differential : IPointerLikeAddress - where OutputVector : IVector + where Address : IPointerLikeAddress + where Address.Differential : IPointerLikeAddress + where Layout : IStorageLayout + where OutputVector : IVector { // Reuse the unbias backward method - linearTransformBwd(dthis, dWeightAddress, doutput); + linearTransformBwd(dthis, dWeightAddress, doutput); let biasOffset = dBiasAddress.d.getOffset(0); // dBias = dOutput [ForceUnroll] - for (int i = 0; i < OutputSize; i++) + for (int i = 0; i < OutputVector.Size; i++) { biasOffset.atomicAdd(i, doutput[i]); } diff --git a/source/standard-modules/neural/istorages.slang b/source/standard-modules/neural/istorages.slang index cc55fe3ed1..e9344b601c 100644 --- a/source/standard-modules/neural/istorages.slang +++ b/source/standard-modules/neural/istorages.slang @@ -1,5 +1,20 @@ implementing neural; +public enum LayoutType : uint32_t +{ + Linear = 0, +} + +internal interface IStorageLayout +{ + internal static const LayoutType Layout; +} + +public struct LinearLayout : IStorageLayout +{ + internal static const LayoutType Layout = LayoutType.Linear; +} + /** Storage interface for accessing neural network parameters. Provides an abstraction for reading/writing parameters from various storage backends @@ -28,6 +43,18 @@ public interface IStorage : IDifferentiablePtrType */ public T read(Address address); + /** + Reads sequential 4 elements starting from the given address and packed into a uint4. + */ + + internal uint4 readUint4(Address baseAddress, Address address) + where DstType : __BuiltinFloatingPointType + where DstType.Differential == DstType; + + internal void writeUint4Atomic(Address baseAddress, Address address, uint4 value) + where SrcType : __BuiltinFloatingPointType + where SrcType.Differential == SrcType; + /** Atomically adds a value to storage (for gradient accumulation). @param[in] address The address to add to. diff --git a/source/standard-modules/neural/ivector.slang b/source/standard-modules/neural/ivector.slang index 6b3547fec3..4738ec82dd 100644 --- a/source/standard-modules/neural/ivector.slang +++ b/source/standard-modules/neural/ivector.slang @@ -12,7 +12,7 @@ and linear algebra operations for neural network computations. @see `InlineVector` @category neural */ -public interface IVector : IDifferentiable +public interface IVector : IDifferentiable, IArrayAccessor where T : __BuiltinFloatingPointType where T.Differential == T { @@ -21,7 +21,7 @@ public interface IVector : IDifferentiable /// The differential type for automatic differentiation. /// @remarks Ensures the differential is also a vector with the same structure. - public associatedtype Differential : IVector; + public associatedtype Differential : IVector; /// Default constructor - initializes vector to zero. public __init(); @@ -36,7 +36,9 @@ public interface IVector : IDifferentiable Array constructor - initializes from an array. @param[in] data Array of N elements to initialize the vector. */ - public __init(T[N] data); + public __init(T[This.Size] data); + + public __init>(InputArray data); /** Copy constructor. 
@@ -72,16 +74,17 @@ public interface IVector : IDifferentiable - `Storage` must conform to `IStorage` - `Storage.Differential` must conform to `IStorage` - `Storage.Address` must equal `Storage.Differential.Address` (same address type) - - `OutputVector` must conform to `IVector` + - `OutputVector` must conform to `IVector` */ [Differentiable] - public OutputVector linearTransform( + public OutputVector linearTransform( Storage weightStorage, no_diff Storage.Address weightAddress) where Storage : IStorage where Storage.Differential : IStorage where Storage.Address == Storage.Differential.Address - where OutputVector : IVector; + where Layout : IStorageLayout + where OutputVector : IVector; /** Evaluates a linear transformation: output = W * this + bias. @@ -107,7 +110,7 @@ public interface IVector : IDifferentiable - `OutputVector` must conform to `IVector` */ [Differentiable] - public OutputVector linearTransform( + public OutputVector linearTransform( Storage weightStorage, Storage biasStorage, no_diff Storage.Address weightAddress, @@ -115,7 +118,8 @@ public interface IVector : IDifferentiable where Storage : IStorage where Storage.Differential : IStorage where Storage.Address == Storage.Differential.Address - where OutputVector : IVector; + where Layout : IStorageLayout + where OutputVector : IVector; /** Evaluates a linear transformation: output = W * this. @@ -135,10 +139,11 @@ public interface IVector : IDifferentiable - `OutputVector` must conform to `IVector` */ [Differentiable] - public OutputVector linearTransform(Address weightAddress) + public OutputVector linearTransform(Address weightAddress) where Address : IPointerLikeAddress where Address.Differential : IPointerLikeAddress - where OutputVector : IVector; + where Layout : IStorageLayout + where OutputVector : IVector; /** Evaluates a linear transformation: output = W * this + bias. @@ -163,9 +168,10 @@ public interface IVector : IDifferentiable - `OutputVector` must conform to `IVector` */ [Differentiable] - public OutputVector linearTransform( + public OutputVector linearTransform( Address weightAddress, Address biasAddress) where Address : IPointerLikeAddress where Address.Differential : IPointerLikeAddress - where OutputVector : IVector; + where Layout : IStorageLayout + where OutputVector : IVector; } diff --git a/source/standard-modules/neural/neural.slang b/source/standard-modules/neural/neural.slang index 3f772de2e3..175c9ef709 100644 --- a/source/standard-modules/neural/neural.slang +++ b/source/standard-modules/neural/neural.slang @@ -37,3 +37,6 @@ __include "inline-vector"; __include "istorages"; __include "bindless-storage"; __include "buffer-storage"; +__include "accelerate-vector-coopmat"; +__include "vectorized-reader"; +__include "shared-memory-pool"; diff --git a/source/standard-modules/neural/shared-memory-pool.slang b/source/standard-modules/neural/shared-memory-pool.slang new file mode 100644 index 0000000000..32c1a629e1 --- /dev/null +++ b/source/standard-modules/neural/shared-memory-pool.slang @@ -0,0 +1,278 @@ +// Unit test mode is used for unit testing the tiled MMA implementation. +// So we can test this single file by providing -DUNIT_TEST to the compiler. +implementing neural; + +#ifdef UNIT_TEST +#define VISIBILITY_LEVEL public +#else +#define VISIBILITY_LEVEL internal +#endif + +#define Max(A, B) ((A) > (B) ? 
(A) : (B)) + +internal typealias SPtr = Ptr; + +internal interface ISharedMemoryPool +{ + internal static SPtr getPointer(); +} + +public interface ISharedMemorySize +{ + public static const uint Bytes; +} + +public struct SharedMemoryPool : ISharedMemoryPool +{ + public static const uint sizeInBytes = ShMemSize.Bytes; + internal static groupshared uint4 data[sizeInBytes / sizeof(uint4)]; + VISIBILITY_LEVEL static SPtr getPointer() + { + return __getAddress(data[0]); + } +} + +internal struct SharedMemoryUsage + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + static const bool IsTraining = ExeMode == ExecutionMode.Training; + typealias TileInfoNormal = TileInfo; + typealias TileInfoTransposed = TileInfo; + typealias CMShape = CoopMatShape; + + // Shared memory A is used to load Tile A. The Size Tile A is determined by the height of matrix A and width of CoopMatA. + // The possible shapes of matrix A can be: + // 1. M x K in A * B -> inference (TransposeA = false) + // 2. K x M in A^T * B -> training (TransposeA = true) + // 3. M x N in outer product of dOut and input. -> training (TransposeA = false) + // In the inference mode, tile A size is always [M x CoopMatA_Width]. + // In the training mode, tile A size is either [M x CoopMatA_Width] or [K x CoopMatA_Width], so we need to choose the max value + static const int SharedMemSizeInVectorMatA = !IsTraining ? + (TileInfoNormal.HeightInElementsTileA * CMShape.COLUMN_A) / CMShape.ElementCountPerVector : + (Max(TileInfoNormal.HeightInElementsTileA, TileInfoTransposed.HeightInElementsTileA) * CMShape.COLUMN_A) / CMShape.ElementCountPerVector; + + // Shared memory B is used to load Tile B. The Size Tile B is determined by the height of CoopMatB and width of Tile B. + // The possible shapes of matrix B in inference mode can be: + // 1. K x N in A * B -> inference + // 2. M x N in A^T * B -> training + // 3. N x K in outer product of dOut and input. -> training + // In the inference mode, tile B size is always [CoopMatB_Height x N]. + // In the training mode, tile B size is either [CoopMatB_Height x N] or [CoopMatB_Height x K], so we need to choose the max value. + + // InputSize is K. + static const int TileBWidthForOuterProduct = ((InputSize + CMShape.COLUMN_B - 1) / CMShape.COLUMN_B) * CMShape.COLUMN_B; + static const int SharedMemSizeInVectorMatB = !IsTraining ? + ((TileInfoNormal.WidthInElementsTileB * CMShape.ROW_B) / CMShape.ElementCountPerVector) : + ((Max(TileInfoNormal.WidthInElementsTileB, TileBWidthForOuterProduct) * CMShape.ROW_B) / CMShape.ElementCountPerVector); + + // Shared memory C is used to store the result of CoopMatC. The size is determened by height of CoopMatC and width of Tile C. + // The possible shapes matrix C can only be: + // 1. M x N in A * B + // 2. K x N in A^T * B + // 3. M x K in outer product of dOut and input. + // Therefore the Tile C size is same as the Tile B size. However, the data type of Tile B can only be half, while tile C can be + // both float and half, so we need to take that into account. + static const int SharedMemSizeInVectorMatC = SharedMemSizeInVectorMatB * sizeof(T) / sizeof(half); +} + +public struct SharedMemorySize0 + : ISharedMemorySize + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + typealias ShMemInfo = SharedMemoryUsage; + + // Notice that in the actual implementation, we always reuse shared memory for Tile B and Tile C because they are always used at + // different stages of the computation, and they have the same size. 
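+    // Illustrative example (hypothetical sizes, just to show how the formula below combines them):
+    // if SharedMemSizeInVectorMatA = 16 uint4s, SharedMemSizeInVectorMatC = 8 uint4s, and SubgroupCount = 4,
+    // then Bytes = (16 + 8 * 4) * sizeof(uint4) = 48 * 16 = 768 bytes.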
+    public static const uint Bytes =
+        (ShMemInfo.SharedMemSizeInVectorMatA + (ShMemInfo.SharedMemSizeInVectorMatC) * SubgroupCount) * sizeof(uint4);
+}
+
+// The following code is a macro-based implementation of the shared memory size calculation.
+// It is used to calculate the shared memory size for a given number of hidden layers.
+// The challenge here is that the size of the shared memory has to be a compile-time constant; however,
+// Slang doesn't really have constexpr functions. So the only way to get the compile-time constant
+// is to use meta-programming to generate code that can be evaluated at compile time.
+// The algorithm is very simple: we just use divide and conquer to calculate the shared memory size.
+// First, we define the base case `SharedMemorySize0` that, given the input and output sizes of a layer, calculates the shared memory size for that layer.
+// Then we can handle a larger number of layers by using the divide-and-conquer strategy. The reason to use macros here
+// is just to reduce the amount of code that we need to write. Under the hood, the macros expand to
+// `SharedMemorySize1` through `SharedMemorySize15`.
+
+// Take an example:
+// ```
+// DEFINE_SHMEM_SIZE(3, 1, 1, PARAM_3, ARG_3_L, ARG_3_R)
+// ```
+// This will be expanded to:
+// ```
+// public struct SharedMemorySize3 SHMEM_WHERE {
+//     internal static const uint a = SharedMemorySize1.Bytes;
+//     internal static const uint b = SharedMemorySize1.Bytes;
+//     public static const uint Bytes = Max(a, b);
+// }
+// ```
+// Here `LN` and `RN` determine how we divide the input sequence of layers into two parts.
+
+// TODO: We shouldn't need such sophisticated meta-programming to achieve this once we have constexpr functions,
+// or once we can provide more advanced variadic generic parameter support such as First(...)/Rest(...), so that
+// we can define SharedMemorySize as a variadic generic struct instead of these pre-defined generics.
+//
+// Note that this implementation is not the most efficient way to calculate the shared memory size, because
+// we could first find the maximum layer size and then do the remaining calculation. But since this computation is
+// not done at run time, we don't need to worry about performance, and we can reuse data structures we
+// already have, so this is the easiest way to implement it.
+
+
+#define UNPACK(...) __VA_ARGS__
+
+// 2. Define your helper macros
+#define SHMEM_WHERE where T : __BuiltinFloatingPointType where T.Differential == T
+#define SHMEM_BASE T, Target, ExeMode, SubgroupSize, SubgroupCount
+
+// 3.
The Core Macro - note the removed spaces around UNPACK +#define DEFINE_SHMEM_SIZE(N, LN, RN, ARGS, L_VALS, R_VALS) \ + public struct SharedMemorySize##N \ + : ISharedMemorySize \ + SHMEM_WHERE \ + { \ + internal static const uint a = SharedMemorySize##LN.Bytes; \ + internal static const uint b = SharedMemorySize##RN.Bytes; \ + public static const uint Bytes = Max(a, b); \ + } + +#define PARAM_1 (uint S0, uint S1, uint S2) +#define ARG_1 (S0, S1, S2) +#define ARG_1_L (S0, S1) +#define ARG_1_R (S1, S2) + +#define PARAM_2 (UNPACK PARAM_1, uint S3) +#define ARG_2 (UNPACK ARG_1, S3) +#define ARG_2_L (S0, S1, S2) +#define ARG_2_R (S2, S3) + +#define PARAM_3 (UNPACK PARAM_2, uint S4) +#define ARG_3 (UNPACK ARG_2, S4) +#define ARG_3_L (S0, S1, S2) +#define ARG_3_R (S2, S3, S4) + +DEFINE_SHMEM_SIZE(1, 0, 0, PARAM_1, ARG_1_L, ARG_1_R) +DEFINE_SHMEM_SIZE(2, 1, 0, PARAM_2, ARG_2_L, ARG_2_R) +DEFINE_SHMEM_SIZE(3, 1, 1, PARAM_3, ARG_3_L, ARG_3_R) + +// from 4 to 7 +#define PARAM_4 (UNPACK PARAM_3, uint S5) +#define ARG_4 (UNPACK ARG_3, S5) +#define ARG_4_R (S4, S5) + +#define PARAM_5 (UNPACK PARAM_4, uint S6) +#define ARG_5 (UNPACK ARG_4, S6) +#define ARG_5_R (UNPACK ARG_4_R, S6) + +#define PARAM_6 (UNPACK PARAM_5, uint S7) +#define ARG_6 (S0, S1, S2, S3, S4, S5, S6, S7) +#define ARG_6_R (UNPACK ARG_5_R, S7) + +#define PARAM_7 (UNPACK PARAM_6, uint S8) +#define ARG_7 (UNPACK ARG_6, S8) +#define ARG_7_R (UNPACK ARG_6_R, S8) + +DEFINE_SHMEM_SIZE(4, 3, 0, PARAM_4, ARG_3, ARG_4_R) +DEFINE_SHMEM_SIZE(5, 3, 1, PARAM_5, ARG_3, ARG_5_R) +DEFINE_SHMEM_SIZE(6, 3, 2, PARAM_6, ARG_3, ARG_6_R) +DEFINE_SHMEM_SIZE(7, 3, 3, PARAM_7, ARG_3, ARG_7_R) + +// from 8 to 15 +#define PARAM_8 (UNPACK PARAM_7, uint S9) +#define ARG_8 (UNPACK ARG_7, S9) +#define ARG_8_R (S8, S9) + +#define PARAM_9 (UNPACK PARAM_8, uint S10) +#define ARG_9 (UNPACK ARG_8, S10) +#define ARG_9_R (UNPACK ARG_8_R, S10) + +#define PARAM_10 (UNPACK PARAM_9, uint S11) +#define ARG_10 (UNPACK ARG_9, S11) +#define ARG_10_R (UNPACK ARG_9_R, S11) + +#define PARAM_11 (UNPACK PARAM_10, uint S12) +#define ARG_11 (UNPACK ARG_10, S12) +#define ARG_11_R (UNPACK ARG_10_R, S12) + +#define PARAM_12 (UNPACK PARAM_11, uint S13) +#define ARG_12 (UNPACK ARG_11, S13) +#define ARG_12_R (UNPACK ARG_11_R, S13) + +#define PARAM_13 (UNPACK PARAM_12, uint S14) +#define ARG_13 (UNPACK ARG_12, S14) +#define ARG_13_R (UNPACK ARG_12_R, S14) + +#define PARAM_14 (UNPACK PARAM_13, uint S15) +#define ARG_14 (UNPACK ARG_13, S15) +#define ARG_14_R (UNPACK ARG_13_R, S15) + +#define PARAM_15 (UNPACK PARAM_14, uint S16) +#define ARG_15 (UNPACK ARG_14, S16) +#define ARG_15_R (UNPACK ARG_14_R, S16) + +DEFINE_SHMEM_SIZE(8, 7, 0, PARAM_8, ARG_7, ARG_8_R) +DEFINE_SHMEM_SIZE(9, 7, 1, PARAM_9, ARG_7, ARG_9_R) +DEFINE_SHMEM_SIZE(10, 7, 2, PARAM_10, ARG_7, ARG_10_R) +DEFINE_SHMEM_SIZE(11, 7, 3, PARAM_11, ARG_7, ARG_11_R) +DEFINE_SHMEM_SIZE(12, 7, 4, PARAM_12, ARG_7, ARG_12_R) +DEFINE_SHMEM_SIZE(13, 7, 5, PARAM_13, ARG_7, ARG_13_R) +DEFINE_SHMEM_SIZE(14, 7, 6, PARAM_14, ARG_7, ARG_14_R) +DEFINE_SHMEM_SIZE(15, 7, 7, PARAM_15, ARG_7, ARG_15_R) + +// Slang doesn't support generic overloading, so we cannot provide the one generic with different number of parameters. +public struct SharedMemorySize + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + #define STRIP_PARENS(x) STRIP_PARENS_I x + #define STRIP_PARENS_I(...) 
__VA_ARGS__ + + public typealias OfLayer1 = SharedMemorySize0; + public typealias OfLayer2 = SharedMemorySize1; + public typealias OfLayer3 = SharedMemorySize2; + public typealias OfLayer4 = SharedMemorySize3; + public typealias OfLayer5 = SharedMemorySize4; + public typealias OfLayer6 = SharedMemorySize5; + public typealias OfLayer7 = SharedMemorySize6; + public typealias OfLayer8 = SharedMemorySize7; + public typealias OfLayer9 = SharedMemorySize8; + public typealias OfLayer10 = SharedMemorySize9; + public typealias OfLayer11 = SharedMemorySize10; + public typealias OfLayer12 = SharedMemorySize11; + public typealias OfLayer13 = SharedMemorySize12; + public typealias OfLayer14 = SharedMemorySize13; + public typealias OfLayer15 = SharedMemorySize14; + public typealias OfLayer16 = SharedMemorySize15; +} + +#if 0 + +// We should implement First/Rest syntax to something like this. +interface IVal { + static const int Value; +} +SharedMemorySize +{ + static const uint SharedMemSizeInBytes = + max(SharedMemorySize, SharedMemorySize); +} + +// Slang doesn't support generic overloading, therefore we cannot provide the pre-defined generics that adds different number of HiddenSize. + +public struct SharedMemorySize + where T : __BuiltinFloatingPointType + where T.Differential == T +{ + typealias ShMemInfo = SharedMemoryUsage; + + // Notice that in the actual implementation, we always reuse shared memory for Tile B and Tile C because they are always used at + // different stages of the computation, and they have the same size. + static const uint SharedMemSizeInBytes = + (ShMemInfo.SharedMemSizeInVectorMatA + (ShMemInfo.SharedMemSizeInVectorMatB) * SubgroupCount) * sizeof(uint4); +} +#endif diff --git a/source/standard-modules/neural/vectorized-reader.slang b/source/standard-modules/neural/vectorized-reader.slang new file mode 100644 index 0000000000..6295837a59 --- /dev/null +++ b/source/standard-modules/neural/vectorized-reader.slang @@ -0,0 +1,211 @@ +implementing neural; + +#ifdef UNIT_TEST +#define VISIBILITY_LEVEL public +#else +#define VISIBILITY_LEVEL internal +#endif + + +internal interface IArrayAccessor : IRWArray +{ + internal void atomicAdd(int index, T value) + { + static_assert(false, "atomicAdd is not supported for IArrayAccessor"); + } +} + +internal extension RWStructuredBuffer : IArrayAccessor +{ + [ForceInline] + override internal void atomicAdd(int index, T value) + { + __atomic_add(this[index], value); + } +} + +internal extension Array : IArrayAccessor +{ + internal __subscript(int index) -> T + { + [ForceInline] + get { return this[index]; } + + [ForceInline] + set { this[index] = newValue; } + } +} + +VISIBILITY_LEVEL enum AccessOp : uint32_t +{ + READ, + WRITE, + ACCUMULATE, + ATOMIC_ADD, +} + +#define COMMON_TYPE_CONSTRAINTS \ + where T : __BuiltinFloatingPointType \ + where U : __BuiltinFloatingPointType \ + where BufferType : IArrayAccessor + +[ForceInline] +internal static void readOneElement(BufferType buffer, int bufferIdx, int elementIdx, inout uint result) + COMMON_TYPE_CONSTRAINTS +{ + const uint shift = BitsShiftPerRead * elementIdx; + T convertedValue; + convertedValue = __realCast(buffer[bufferIdx]); + switch (NBytes) + { + case 1: + result |= uint(bit_cast(convertedValue)) << shift; + break; + case 2: + result |= uint(bit_cast(convertedValue)) << shift; + break; + case 4: + result |= uint(bit_cast(convertedValue)) << shift; + break; + default: + static_assert(false, "Unsupported data type T"); + } +} + +[ForceInline] +internal static void writeOneElement(inout 
BufferType buffer, int bufferIdx, int elementIdx, uint value) + COMMON_TYPE_CONSTRAINTS +{ + const uint shift = BitsShiftPerWrite * elementIdx; + U convertedValue; + switch (NBytes) + { + case 1: + convertedValue = __realCast(bit_cast((uint8_t)(value >> shift))); + break; + case 2: + convertedValue = __realCast(bit_cast((uint16_t)(value >> shift))); + break; + case 4: + convertedValue = __realCast(bit_cast((uint)(value >> shift))); + break; + default: + static_assert(false, "Unsupported data type T"); + } + + switch (Op) + { + case AccessOp.WRITE: + buffer[bufferIdx] = convertedValue; + break; + case AccessOp.ACCUMULATE: + buffer[bufferIdx] = buffer[bufferIdx] + convertedValue; + break; + case AccessOp.ATOMIC_ADD: + buffer.atomicAdd(bufferIdx, convertedValue); + break; + default: + static_assert(false, "Unsupported access operation"); + } +} + +[ForceInline] +internal static void accessUint4Aligned(inout BufferType buffer, int startIndex, inout uint4 value) + COMMON_TYPE_CONSTRAINTS +{ + const int nBytes = sizeof(T); + const int WritePerElement = 4 / nBytes; + const int BitsShiftPerWrite = 32 / WritePerElement; + + if (Op == AccessOp.READ) + value = uint4(0, 0, 0, 0); + + [ForceUnroll] + for (int i = 0; i < 4; i++) + { + [ForceUnroll] + for (int j = 0; j < WritePerElement; j++) + { + int index = startIndex + i * WritePerElement + j; + switch (Op) + { + case AccessOp.READ: + readOneElement(buffer, index, j, value[i]); + break; + case AccessOp.WRITE: + case AccessOp.ACCUMULATE: + case AccessOp.ATOMIC_ADD: + writeOneElement(buffer, index, j, value[i]); + break; + default: + static_assert(false, "Unsupported access operation"); + } + } + } +} + +[ForceInline] +internal void accessUint4(BufferType buffer, int baseIndex, int startIndex, inout uint4 value) + COMMON_TYPE_CONSTRAINTS +{ + if (IsAligned) + { + // Call the aligned version of readUint4 which is branchless. + accessUint4Aligned(buffer, startIndex, value); + return; + } + + if (Op == AccessOp.READ) + value = uint4(0, 0, 0, 0); + + // T is the type of source (read) or destination (write) data type. We will always pack few elements into a uint4. + // So T will determine how many elements we can pack into a uint4. + // If U is different from T, we will first convert from U to T (in read operation) or from T to U (in write operation). + // But U will not determined how many elements we can read or write, only T will. + const int nBytes = sizeof(T); + const int ReadPerElement = 4 / nBytes; + const int BitsShiftPerRead = 32 / ReadPerElement; + + const int x = (startIndex - baseIndex) % Stride; + + // end address of this read [address+length-1] + const int endAddress = (x + 4 * ReadPerElement - 1); + + // this is same as paddingCount = endAddress < AlignedStride ? 0 : AlignedStride - endAddress + 1 + const int paddingCount = max(0, endAddress - Stride + 1); + const int elementsToRead = (4 * ReadPerElement) - paddingCount; + + [ForceUnroll] + for (int i = 0; i < 4; i++) + { + int offset = i * ReadPerElement; + [ForceUnroll] + for (int j = 0; j < ReadPerElement; j++) + { + // 4 * ReadPerElement is the total number of elements we can read from the buffer. + // paddingCount is the number of the elements we need to pad. + // e.g. if ReadPerElement is 2, paddingCount is 4.Because (4 * 2 - 4 == 4), so we can + // just stop reading when offset bigger than 3. 
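+                // Worked example of the padding math above (T = half and Stride = 4 are assumed illustration values):
+                // ReadPerElement = 4 / sizeof(half) = 2, so one uint4 access covers 4 * 2 = 8 elements.
+                // With x = 0, endAddress = 0 + 8 - 1 = 7, paddingCount = max(0, 7 - 4 + 1) = 4,
+                // and elementsToRead = 8 - 4 = 4, so the loop stops once offset reaches 4.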
+                offset += j;
+                if (offset >= elementsToRead)
+                {
+                    return;
+                }
+
+                int index = (startIndex + offset);
+                switch (Op)
+                {
+                case AccessOp.READ:
+                    readOneElement(buffer, index, j, value[i]);
+                    break;
+                case AccessOp.WRITE:
+                case AccessOp.ACCUMULATE:
+                case AccessOp.ATOMIC_ADD:
+                    writeOneElement(buffer, index, j, value[i]);
+                    break;
+                default:
+                    static_assert(false, "Unsupported access operation");
+                }
+            }
+        }
+}
diff --git a/tests/cooperative-matrix/length.slang b/tests/cooperative-matrix/length.slang
index 16d90259fe..920a3535d5 100644
--- a/tests/cooperative-matrix/length.slang
+++ b/tests/cooperative-matrix/length.slang
@@ -1,13 +1,19 @@
-//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -output-using-type -emit-spirv-directly
-//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-cuda -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK-VK):-vk -output-using-type -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK-CUDA):-cuda -output-using-type
 // Note the length is NOT row * column.
 // When the memory scope is set to subgroup, each thread gets 16 * 16 / 32 = 8 where 32 is the value used in `numthreads`.
-//CHK:8
-//CHK:4
-//CHK:16
-//CHK:4
+//CHK-VK:8
+//CHK-VK:4
+//CHK-VK:16
+//CHK-VK:4
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment
+//CHK-CUDA:8
+//CHK-CUDA:16
+//CHK-CUDA:16
+//CHK-CUDA:16
 //TEST_INPUT:ubuffer(stride=4, count=4):out,name=outputBuffer
 RWStructuredBuffer outputBuffer;
@@ -27,7 +33,10 @@ void computeMain()
         outputBuffer[2] = CoopMat.GetLength();
     case cuda:
+        // For the f16 type, with any shape, each fragment of A and B is a vector of 8 f16x2 registers; each register contains 2 elements.
         outputBuffer[1] = CoopMat.GetLength();
+
+        // For the int8 type with the m32n8k16 shape, each fragment of A is a vector of 4 b32 registers; each register contains 4 elements.
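+        // So GetLength() works out to registers-per-fragment * elements-per-register: 8 * 2 = 16 for the f16
+        // case above and 4 * 4 = 16 for this int8 m32n8k16 case, matching the CHK-CUDA values at the top.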
outputBuffer[2] = CoopMat.GetLength(); } diff --git a/tests/cooperative-matrix/load-store-groupshared.slang b/tests/cooperative-matrix/load-store-groupshared.slang index 9df0cfbad2..d53adb5a7f 100644 --- a/tests/cooperative-matrix/load-store-groupshared.slang +++ b/tests/cooperative-matrix/load-store-groupshared.slang @@ -10,24 +10,48 @@ // CHECK-NEXT: 7 // CHECK-NEXT: 8 +// CHECK: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 +#pragma warning(disable:41017) //TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=input ByteAddressBuffer input; -//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +// TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer RWStructuredBuffer outputBuffer; +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer1 +RWStructuredBuffer outputBuffer1; + using namespace linalg; groupshared int32_t[256] tempShared; +groupshared uint4[256 / 4] tempShared1; [numthreads(32, 1, 1)] +[shader("compute")] void computeMain() { - let stride = 16; + let stride = 4; let mat = coopMatLoad(input, 0, stride); mat.Store(tempShared, 0, stride); let result = coopMatLoad(tempShared, 0, stride); result.Store(outputBuffer, 0, stride); + + Ptr tempSharedPtr = __getAddress(tempShared[0]); + let matC = CoopMat.Load(tempSharedPtr, stride); + + + Ptr tempSharedPtr1 = __getAddress(tempShared1[0]); + matC.Store(tempSharedPtr1, stride); + + let matC1 = CoopMat.Load(tempSharedPtr1, stride); + matC1.Store(outputBuffer1, 0, stride); } diff --git a/tests/neural/basic-coopmat-vector-test.slang b/tests/neural/basic-coopmat-vector-test.slang new file mode 100644 index 0000000000..b3c2add477 --- /dev/null +++ b/tests/neural/basic-coopmat-vector-test.slang @@ -0,0 +1,257 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -xslang -DTEST_HALF=0 -emit-spirv-directly +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 + +import neural; + +#if TEST_HALF +typealias ElementType = half; +#else +typealias ElementType = float; +#endif + +// set up a 2x4 matrix for input parameters, the last 2 elements are for bias +//TEST_INPUT: ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0], stride=4):name=parametersFloat +// 1 2 3 4 +// 5 6 7 8 +// bias = {9.0, 10.0} +RWStructuredBuffer parametersFloat; + +//TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=2):name=parameters +RWStructuredBuffer parameters; + +// Create a buffer to store the test result +//TEST_INPUT: ubuffer(data=[0 0], stride=4):out,name=testResult +RWStructuredBuffer testResult; + +//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):name=dInput +RWStructuredBuffer dInput; + +// set up a 2x4 matrix for derivative of parameters, the last 2 elements are for derivative of bias +//TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):name=dParameters +RWStructuredBuffer dParameters; + +typealias BufferStorage = StructuredBufferStorage; + +static const int InputSize = 4; +static const int OutputSize = 2; +static const int BatchSize = 32; +static const int SubgroupSize = 32; +static const int workgroupCount = 
BatchSize / SubgroupSize; + +typealias ShMemSize = SharedMemorySize< ElementType, TargetEnum.CUDA, ExecutionMode.Training, SubgroupSize, BatchSize / SubgroupSize>; +typealias ShMemSizeLayer1 = ShMemSize.OfLayer1; + +typealias SPtr = Ptr; + +[Differentiable] +OutputVector MatVecMul( + InputVector input, + BufferStorage weightStorage, + BufferStorage.Address weightAddress) + where InputVector : IVector + where OutputVector : IVector +{ + var outputVec = input.linearTransform(weightStorage, weightAddress); + return outputVec; +} + +[Differentiable] +OutputVector MatVecMulAdd( + InputVector input, + BufferStorage weightStorage, + BufferStorage biasStorage, + BufferStorage.Address weightAddress, + BufferStorage.Address biasAddress) + where InputVector : IVector + where OutputVector : IVector +{ + var outputVec = input.linearTransform(weightStorage, biasStorage, weightAddress, biasAddress); + return outputVec; +} + +// Basic test on MatMul without bias, this test covers both forward and backward pass +void BasicTestWithoutBias(int tid, int resIndex) +{ + typealias ShMemPool = SharedMemoryPool; + typealias InVectorType = WaveTangledVector; + typealias OutVectorType = WaveTangledVector; + + ElementType[InputSize] inputData = { ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0) }; + InVectorType input = InVectorType(inputData); + + BufferStorage weightStorage = BufferStorage(parameters); + BufferStorage dweightStorage = BufferStorage(dParameters); + BufferStorage.Address weightAddress = 0; + + // Run the forward pass + let outputVec = MatVecMul(input, weightStorage, weightAddress); + + // serialRead<16, half>(tid, __getAddress(shMem[0])); + // serialRead<16, half>(tid, __getAddress(shMem[0]) + 32); + + // (1*1 + 2*2 + 3*3 + 4*4) = 30.0 + // (5*1 + 6*2 + 7*3 + 8*4) = 70.0 + bool isPassed = true; + isPassed = isPassed && (outputVec[0] == 30.0 && outputVec[1] == 70.0); + + var weightDiffPair = DifferentialPtrPair(weightStorage, dweightStorage); + let dRes = OutVectorType(1.0f); + var dPair = diffPair(input); + + // Run the backward pass + // dInput = W^T * dOutput + // dInput = {6, 8, 10, 12} + bwd_diff(MatVecMul) + (dPair, weightDiffPair, weightAddress, dRes); + + isPassed = isPassed && + dPair.d[0] == 6.0 && dPair.d[1] == 8.0 && dPair.d[2] == 10.0 && dPair.d[3] == 12.0; + + // dW = dOutput * dInput^T + // dW = [1, 1]^T * [1, 2, 3, 4] + // = [[1, 2, 3, 4]; [1, 2, 3, 4]] + // But since it's accumulated cross 32 threads, so the result is 32 times of the original result. + // So the result should be 32 * [[1, 2, 3, 4]; [1, 2, 3, 4]] = [[32, 64, 96, 128]; [32, 64, 96, 128]] + isPassed = isPassed && + dParameters[0] == 32.0 && dParameters[1] == 64.0 && dParameters[2] == 96.0 && dParameters[3] == 128.0 && + dParameters[4] == 32.0 && dParameters[5] == 64.0 && dParameters[6] == 96.0 && dParameters[7] == 128.0; + + isPassed = WaveActiveAllTrue(isPassed); + if (tid == 0) + { + testResult[resIndex] = isPassed ? 
1 : 0; + } +} + +// Basic test on MatMul with bias, this test covers both forward and backward pass +void BasicTestWithBias(int tid, int resIndex) +{ + typealias ShMemPool = SharedMemoryPool; + typealias InVectorType = WaveTangledVector; + typealias OutVectorType = WaveTangledVector; + + ElementType[InputSize] inputData = { ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0) }; + InVectorType input = InVectorType(inputData); + + BufferStorage weightStorage = BufferStorage(parameters); + BufferStorage biasStorage = BufferStorage(parameters); + BufferStorage dweightStorage = BufferStorage(dParameters); + BufferStorage dbiasStorage = BufferStorage(dParameters); + + BufferStorage.Address weightAddress = 0; + BufferStorage.Address biasAddress = 8; + + // Run the forward pass + let outputVec = MatVecMulAdd(input, weightStorage, biasStorage, weightAddress, biasAddress); + + bool isPassed = true; + isPassed = isPassed && (outputVec[0] == ElementType(39.0) && outputVec[1] == ElementType(80.0)); + + + var weightDiffPair = DifferentialPtrPair(weightStorage, dweightStorage); + var biasDiffPair = DifferentialPtrPair(biasStorage, dbiasStorage); + let dOutput = OutVectorType(1.0); + var dPair = diffPair(input); + + // Run the backward pass + // dInput = W^T * dOutput + // dInput = {6, 8, 10, 12} + bwd_diff(MatVecMulAdd) + ( dPair, weightDiffPair, biasDiffPair, weightAddress, biasAddress, dOutput); + + isPassed = isPassed && + dPair.d[0] == 6.0 && dPair.d[1] == 8.0 && dPair.d[2] == 10.0 && dPair.d[3] == 12.0; + + + // dW = dOutput * dInput^T + // dW = [1, 1]^T * [1, 2, 3, 4] + // = [[1, 2, 3, 4]; [1, 2, 3, 4]] + isPassed = isPassed && + dParameters[0] == 32.0 && dParameters[1] == 64.0 && dParameters[2] == 96.0 && dParameters[3] == 128.0 && + dParameters[4] == 32.0 && dParameters[5] == 64.0 && dParameters[6] == 96.0 && dParameters[7] == 128.0; + + // dBias = dOutput + // dBias = {1, 1} + isPassed = isPassed && + dParameters[8] == 32.0 && dParameters[9] == 32.0; + + + isPassed = WaveActiveAllTrue(isPassed); + if (tid == 0) + { + testResult[resIndex] = isPassed ? 1 : 0; + } +} + + +// This function is just used for debugging, not for verification. So keep it here. 
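+// For reference when reading serialRead's output below (an assumed example value, not produced by the tests above):
+// if one packed 32-bit word in shared memory is 0x40003C00, the low 16 bits 0x3C00 decode to half 1.0 and the
+// high 16 bits 0x4000 decode to half 2.0, so serialRead would print "1.0 2.0" for that word.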
+void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int id = 0; id < 16; id++) + { + printf("tid: %d\n", id); + int strideInVector = Stride / (sizeof(uint4) / sizeof(T)); + for (int i = 0; i < strideInVector; i++) + { + uint4 values = sharedMem[id * strideInVector + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + else + { + printf("%f ", bit_cast(value)); + } + } + } + printf("\n"); + } +} + +void cleanupDParameters() +{ + for (int i = 0; i < 10; i++) + { + dParameters[i] = ElementType(0.0); + } +} + +void setupParameters(uint tid) +{ + if (tid == 0) + { + for (int i = 0; i < 10; i++) + { + parameters[i] = ElementType(parametersFloat[i]); + } + } +} + +[shader("compute")] +[numthreads(BatchSize, 1, 1)] +void computeMain(uint tid : SV_DispatchThreadID) +{ + setupParameters(tid); + GroupMemoryBarrierWithWaveSync(); + BasicTestWithoutBias(tid, 0); + // BUFFER: 1 + + cleanupDParameters(); + BasicTestWithBias(tid, 1); + // BUFFER: 1 +} diff --git a/tests/neural/basic-inline-vector-test-bindless-storage.slang b/tests/neural/basic-inline-vector-test-bindless-storage.slang index 67c0dcebb8..c7e7a3d995 100644 --- a/tests/neural/basic-inline-vector-test-bindless-storage.slang +++ b/tests/neural/basic-inline-vector-test-bindless-storage.slang @@ -62,10 +62,10 @@ uniform RWStructuredBuffer parametersFloat; OutputVector TestInlineVectorMatMul( InputVector input, Address address) - where InputVector : IVector - where OutputVector : IVector + where InputVector : IVector + where OutputVector : IVector { - var outputVec = input.linearTransform<2, Address, OutputVector>(address); + var outputVec = input.linearTransform(address); return outputVec; } @@ -74,10 +74,10 @@ OutputVector TestInlineVectorMatMulAdd( InputVector input, Address weightAddress, Address biasAddress) - where InputVector : IVector - where OutputVector : IVector + where InputVector : IVector + where OutputVector : IVector { - var outputVec = input.linearTransform<2, Address, OutputVector>(weightAddress, biasAddress); + var outputVec = input.linearTransform(weightAddress, biasAddress); return outputVec; } @@ -102,12 +102,12 @@ bool BasicTestWithoutBias() // Run the backward pass // dInput = W^T * dOutput - // dInput = {3, 7, 11, 15} + // dInput = {6, 8, 10, 12} bwd_diff(TestInlineVectorMatMul, InlineVector>) (dPair, weightDiffPair, dRes); isPassed = isPassed && - dPair.d[0] == 3.0 && dPair.d[1] == 7.0 && dPair.d[2] == 11.0 && dPair.d[3] == 15.0; + dPair.d[0] == 6.0 && dPair.d[1] == 8.0 && dPair.d[2] == 10.0 && dPair.d[3] == 12.0; // dW = dOutput * dInput^T // dW = [1, 1]^T * [1, 2, 3, 4] @@ -143,12 +143,12 @@ bool BasicTestWithBias() // Run the backward pass // dInput = W^T * dOutput - // dInput = {3, 7, 11, 15} + // dInput = {6, 8, 10, 12} bwd_diff(TestInlineVectorMatMulAdd, InlineVector>) (dPair, weightDiffPair, biasDiffPair, dOutput); isPassed = isPassed && - dPair.d[0] == 3.0 && dPair.d[1] == 7.0 && dPair.d[2] == 11.0 && dPair.d[3] == 15.0; + dPair.d[0] == 6.0 && dPair.d[1] == 8.0 && dPair.d[2] == 10.0 && dPair.d[3] == 12.0; // dW = dOutput * dInput^T // dW = [1, 1]^T * [1, 2, 3, 4] diff --git a/tests/neural/basic-inline-vector-test.slang b/tests/neural/basic-inline-vector-test.slang index 
8b7c47e698..23f287dfc8 100644 --- a/tests/neural/basic-inline-vector-test.slang +++ b/tests/neural/basic-inline-vector-test.slang @@ -1,10 +1,10 @@ -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -xslang -DTEST_HALF=0 -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-dx12 -compute -shaderobj -profile cs_6_6 -xslang -experimental-feature -output-using-type -xslang -DTEST_HALF=0 -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-mtl -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-dx12 -compute -shaderobj -profile cs_6_6 -xslang -experimental-feature -output-using-type -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-mtl -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=0 // Currently, only CUDA supports atomicAdd on half. So we can only test fp16 on CUDA. -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 import neural; #if TEST_HALF @@ -41,10 +41,10 @@ OutputVector TestInlineVectorMatMul( InputVector input, BufferStorage weightStorage, BufferStorage.Address weightAddress) - where InputVector : IVector - where OutputVector : IVector + where InputVector : IVector + where OutputVector : IVector { - var outputVec = input.linearTransform<2, BufferStorage, OutputVector>(weightStorage, weightAddress); + var outputVec = input.linearTransform(weightStorage, weightAddress); return outputVec; } @@ -55,10 +55,10 @@ OutputVector TestInlineVectorMatMulAdd( BufferStorage biasStorage, BufferStorage.Address weightAddress, BufferStorage.Address biasAddress) - where InputVector : IVector - where OutputVector : IVector + where InputVector : IVector + where OutputVector : IVector { - var outputVec = input.linearTransform<2, BufferStorage, OutputVector>(weightStorage, biasStorage, weightAddress, biasAddress); + var outputVec = input.linearTransform(weightStorage, biasStorage, weightAddress, biasAddress); return outputVec; } @@ -85,12 +85,12 @@ bool BasicTestWithoutBias() // Run the backward pass // dInput = W^T * dOutput - // dInput = {3, 7, 11, 15} + // dInput = {6, 8, 10, 12} bwd_diff(TestInlineVectorMatMul, InlineVector>) (dPair, weightDiffPair, weightAddress, dRes); isPassed = isPassed && - dPair.d[0] == 3.0 && dPair.d[1] == 7.0 && dPair.d[2] == 11.0 && dPair.d[3] == 15.0; + dPair.d[0] == 6.0 && dPair.d[1] == 8.0 && dPair.d[2] == 10.0 && dPair.d[3] == 12.0; // dW = dOutput * dInput^T // dW = [1, 1]^T * 
[1, 2, 3, 4] @@ -132,12 +132,12 @@ bool BasicTestWithBias() // Run the backward pass // dInput = W^T * dOutput - // dInput = {3, 7, 11, 15} + // dInput = {6, 8, 10, 12} bwd_diff(TestInlineVectorMatMulAdd, InlineVector>) (dPair, weightDiffPair, biasDiffPair, weightAddress, biasAddress, dOutput); isPassed = isPassed && - dPair.d[0] == 3.0 && dPair.d[1] == 7.0 && dPair.d[2] == 11.0 && dPair.d[3] == 15.0; + dPair.d[0] == 6.0 && dPair.d[1] == 8.0 && dPair.d[2] == 10.0 && dPair.d[3] == 12.0; // dW = dOutput * dInput^T // dW = [1, 1]^T * [1, 2, 3, 4] diff --git a/tests/neural/basic-pointer-address-extension.slang b/tests/neural/basic-pointer-address-extension.slang new file mode 100644 index 0000000000..4633da5989 --- /dev/null +++ b/tests/neural/basic-pointer-address-extension.slang @@ -0,0 +1,183 @@ +// This test is to test InlineVector with Bindless Storage. + +// Both Pointer and DescriptorHandle are belong to this category. However, since there is issue on our test-infrasture on Vulkan for +// the DescriptorHandle support, we can only test Pointer for Vulkan now. See issue #8631. + +// We don't want to enable the pointer test on glsl, because there is a bug in emitting glsl code. See issue #8630. + +// Currently, only CUDA supports atomicAdd on half. So we can only test fp16 on CUDA. + + +//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 +//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -output-using-type -emit-spirv-directly -xslang -experimental-feature -xslang -DTEST_HALF=0 +//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-mtl -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=0 + +import neural; + +#if TEST_HALF +typealias ElementType = half; +#else +typealias ElementType = float; +#endif + +// Create a buffer to store the test result +//TEST_INPUT: ubuffer(data=[0 0], stride=4):out,name=testResult +RWStructuredBuffer testResult; + + +// set up a 2x4 matrix for input parameters, the last 2 elements are for bias +//TEST_INPUT: set parametersFloat = ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0], stride=4) +// 1 2 3 4 +// 5 6 7 8 +// bias = {9.0, 10.0} +uniform RWStructuredBuffer parametersFloat; + +//TEST_INPUT: set parameters = ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=2) +//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):name=dInput + +// set up a 2x4 matrix for derivative of parameters, the last 2 elements are for derivative of bias +//TEST_INPUT: set dParameters = ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4) + + +uniform ElementType* parameters; +uniform ElementType* dInput; +uniform ElementType* dParameters; + + +[Differentiable] +OutputVector TestInlineVectorMatMul( + InputVector input, + ElementType* parameters) + where InputVector : IVector + where OutputVector : IVector +{ + var outputVec = input.linearTransform<2, ElementType*, OutputVector>(parameters); + return outputVec; +} + +// [Differentiable] +// OutputVector TestInlineVectorMatMulAdd( +// InputVector input, +// ElementType* weightAddress, +// ElementType* biasAddress) +// where InputVector : IVector +// where OutputVector : IVector +// { 
+// var outputVec = input.linearTransform<2, ElementType*, OutputVector>(weightAddress, biasAddress); +// return outputVec; +// } + +// Basic test on MatMul without bias, this test covers both forward and backward pass +bool BasicTestWithoutBias() +{ + ElementType[4] inputData = {ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0)}; + let input = InlineVector(inputData); + + // Run the forward pass + let outputVec = TestInlineVectorMatMul, InlineVector>( input, parameters); + + // (1*1 + 2*2 + 3*3 + 4*4) = 30.0 + // (5*1 + 6*2 + 7*3 + 8*4) = 70.0 + bool isPassed = (outputVec[0] == 30.0 && outputVec[1] == 70.0); + + var weightDiffPair = DifferentialPtrPair(parameters, dParameters); + let dRes = InlineVector(1.0f); + var dPair = diffPair(input); + + // Run the backward pass + // dInput = W^T * dOutput + // dInput = {3, 7, 11, 15} + bwd_diff(TestInlineVectorMatMul, InlineVector>) + (dPair, weightDiffPair, dRes); + + isPassed = isPassed && + dPair.d[0] == 3.0 && dPair.d[1] == 7.0 && dPair.d[2] == 11.0 && dPair.d[3] == 15.0; + + // dW = dOutput * dInput^T + // dW = [1, 1]^T * [1, 2, 3, 4] + // = [[1, 2, 3, 4]; [1, 2, 3, 4]] + isPassed = isPassed && + dParameters[0] == 1.0 && dParameters[1] == 2.0 && dParameters[2] == 3.0 && dParameters[3] == 4.0 && + dParameters[4] == 1.0 && dParameters[5] == 2.0 && dParameters[6] == 3.0 && dParameters[7] == 4.0; + + return isPassed; +} + +// Basic test on MatMul with bias, this test covers both forward and backward pass +// bool BasicTestWithBias() +// { +// ElementType[4] inputData = {ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0)}; +// let input = InlineVector(inputData); +// Address weightAddress = Address(parameters); +// Address biasAddress = weightAddress.getOffset(8); +// Address dWeightAddress = Address(dParameters); +// Address dBiasAddress = dWeightAddress.getOffset(8); + +// // Run the forward pass +// let outputVec = TestInlineVectorMatMulAdd, InlineVector>(input, weightAddress, biasAddress); + +// // (1*1 + 2*2 + 3*3 + 4*4) + 9.0 = 39.0 +// // (5*1 + 6*2 + 7*3 + 8*4) + 10.0 = 80.0 +// bool isPassed = (outputVec[0] == 39.0 && outputVec[1] == 80.0); + +// var weightDiffPair = DifferentialPtrPair
(weightAddress, dWeightAddress); +// var biasDiffPair = DifferentialPtrPair
(biasAddress, dBiasAddress); +// let dOutput = InlineVector(1.0); +// var dPair = diffPair(input); + +// // Run the backward pass +// // dInput = W^T * dOutput +// // dInput = {3, 7, 11, 15} +// bwd_diff(TestInlineVectorMatMulAdd, InlineVector>) +// (dPair, weightDiffPair, biasDiffPair, dOutput); + +// isPassed = isPassed && +// dPair.d[0] == 3.0 && dPair.d[1] == 7.0 && dPair.d[2] == 11.0 && dPair.d[3] == 15.0; + +// // dW = dOutput * dInput^T +// // dW = [1, 1]^T * [1, 2, 3, 4] +// // = [[1, 2, 3, 4]; [1, 2, 3, 4]] +// isPassed = isPassed && +// dParameters[0] == 1.0 && dParameters[1] == 2.0 && dParameters[2] == 3.0 && dParameters[3] == 4.0 && +// dParameters[4] == 1.0 && dParameters[5] == 2.0 && dParameters[6] == 3.0 && dParameters[7] == 4.0; + +// // dBias = dOutput +// // dBias = {1, 1} +// isPassed = isPassed && +// dParameters[8] == 1.0 && dParameters[9] == 1.0; + +// return isPassed; +// } + +void cleanupDParameters() +{ + for (int i = 0; i < 10; i++) + { + dParameters[i] = ElementType(0.0); + } +} + +void setupParameters() +{ + for (int i = 0; i < 10; i++) + { + parameters[i] = ElementType(parametersFloat[i]); + } +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + setupParameters(); + + testResult[0] = BasicTestWithoutBias(); + + // cleanupDParameters(); + + // testResult[1] = BasicTestWithBias(); + + // BUFFER: 1 + // BUFFER: 1 +} diff --git a/tests/neural/bias-sum-reduce.slang b/tests/neural/bias-sum-reduce.slang new file mode 100644 index 0000000000..b641fc9838 --- /dev/null +++ b/tests/neural/bias-sum-reduce.slang @@ -0,0 +1,193 @@ + +// On Vulkan, we can only test float type for now because atomicAdd is not supported for half on our CI machine. +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=0 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=1 +import neural; +#pragma warning(disable: 41017) + + +// TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4, count=4):out, name=resultBuffer +RWStructuredBuffer resultBuffer; + + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// TEST_INPUT: ubuffer(stride=4, count=64):name=outputBuffer +RWStructuredBuffer outputBuffer; + +typealias BufferStorage = StructuredBufferStorage; + +static const int SubgroupSize = 32; +static const int WorkgroupSize = 128; +static const int workgroupCount = WorkgroupSize / SubgroupSize; +static const int InputSize = 8; +static const int MaxOutputSize = 64; + +typealias ShMemSize = SharedMemorySize< DType, TargetEnum.CUDA, ExecutionMode.Training, SubgroupSize, workgroupCount>; +static const int ShMemSizeInBytes = ShMemSize.OfLayer1.Bytes; + +groupshared uint4 s_sharedMemory[ShMemSizeInBytes / sizeof(uint4)]; + +typealias SPtr = Ptr; + +// Basic test on MatMul without bias, this test covers both forward and backward pass +void TestBiasSumReduce(uint tid) +{ + BufferStorage storage = BufferStorage(outputBuffer); + + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedM; + const int InSize = MMA.Uint4AlignedK; + + DType dOutVector[OutSize] = {}; + + for (int i = 0; i < OutputSize; i++) + { + dOutVector[i] = DType((i + 1) * 0.01); + } + + SPtr shPtr = 
__getAddress(s_sharedMemory[0]); + + uint subgroupId = tid / SubgroupSize; + + // perform dOut OPA Input + MMA.sumReduceRows(shPtr, dOutVector, subgroupId, storage, 0); + + // serialRead<16, DType>(tid, shPtr); +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int id = 0; id < 16; id++) + { + printf("tid: %d\n", id); + int strideInVector = Stride / (sizeof(uint4) / sizeof(T)); + for (int i = 0; i < strideInVector; i++) + { + uint4 values = sharedMem[id * strideInVector + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.2f %.2f ", float(aa), float(bb)); + } + else + { + printf("%f ", bit_cast(value)); + } + } + } + printf("\n"); + } +} + +void serialRead(uint tid, RWStructuredBuffer outputBuffer) +{ + GroupMemoryBarrierWithWaveSync(); + + if (tid > 0) + return; + + for (int i = 0; i < Size; i++) + { + DType value = outputBuffer[i]; + printf("%.4f ", float(value)); + } + printf("\n"); +} + +void test(uint tid, uint resIndex) +{ + __target_switch + { + case cuda: + TestBiasSumReduce(tid); + break; + case spirv: + TestBiasSumReduce(tid); + break; + } + + // serialRead(tid, outputBuffer); + + int subgroupIndex = tid / SubgroupSize; + if (subgroupIndex != 0) + return; + + if (tid == 0) + { + bool isPassed = true; + for (int i = 0; i < OutputSize; i++) + { + DType value = outputBuffer[i]; + DType expected = DType((i + 1) * 0.01 * WorkgroupSize); + if (abs(value - expected) > DType(0.001)) + { + isPassed = false; + break; + } + } + resultBuffer[resIndex] = isPassed ? 
1 : 0; + } +} + +void cleanOutputBuffer(uint tid) +{ + for (int i = 0; i < MaxOutputSize; i++) + { + outputBuffer[i] = DType(0.0f); + } + AllMemoryBarrierWithGroupSync(); +} + +[shader("compute")] +[numthreads(WorkgroupSize, 1, 1)] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + cleanOutputBuffer(tid); + test<3>(tid, 0); + AllMemoryBarrierWithGroupSync(); + } + + { + cleanOutputBuffer(tid); + test<15>(tid, 1); + AllMemoryBarrierWithGroupSync(); + } + + { + cleanOutputBuffer(tid); + test<23>(tid, 2); + AllMemoryBarrierWithGroupSync(); + } + + { + cleanOutputBuffer(tid); + test<47>(tid, 3); + AllMemoryBarrierWithGroupSync(); + } + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/common.slang b/tests/neural/common.slang new file mode 100644 index 0000000000..8c32acf817 --- /dev/null +++ b/tests/neural/common.slang @@ -0,0 +1,64 @@ +#pragma warning(disable: 41017) + + +public bool equals(T a, T b) where T : __BuiltinFloatingPointType +{ + return abs(a - b) < T(1e-2f); +} + +public void initBias(uint tid, RWStructuredBuffer biasBuffer) + where T : __BuiltinFloatingPointType +{ + if (tid < Size) + biasBuffer[tid] = T(tid); + + AllMemoryBarrierWithGroupSync(); +} + +public bool collectResults(uint tid, bool perThreadResult) +{ + static groupshared bool s_resultMem[BatchSize / 32]; + + // each warp will check the result + bool res = WaveActiveAllTrue(perThreadResult); + + int subgroupIndex = tid / 32; + int laneId = WaveGetLaneIndex(); + // use first lane of each warp to write the result to the shared memory so that + // we can reduce the result further + if (laneId == 0) + { + // Note max subgroupIndex is "BatchSize/32 - 1" + s_resultMem[subgroupIndex] = res; + } + + AllMemoryBarrierWithGroupSync(); + + int warpCount = BatchSize / 32; + bool finalResult = true; + // We will use the "warpCount" threads in the first warp to reduce the result. 
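+    // Illustrative example: with BatchSize = 64 and 32-wide waves, warpCount = 2, so
+    // threads 0 and 1 each read one per-warp result from s_resultMem before the
+    // WaveActiveAllTrue below folds them into a single pass/fail value.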
+ if (tid < warpCount) + { + finalResult = s_resultMem[tid]; + } + + finalResult = WaveActiveAllTrue(finalResult); + + return finalResult; +} + +public void setBufferOneValue(uint tid, RWStructuredBuffer buffer, T value) + where T : __BuiltinFloatingPointType +{ + int numIter = (Size + WorkgroupSize - 1) / WorkgroupSize; + for (int i = 0; i < numIter; i++) + { + int index = i * WorkgroupSize + tid; + if (index >= Size) + break; + + buffer[index] = value; + } + + AllMemoryBarrierWithGroupSync(); +} diff --git a/tests/neural/mma-helper-test-multi-warps-arbitrary-size.slang b/tests/neural/mma-helper-test-multi-warps-arbitrary-size.slang new file mode 100644 index 0000000000..342643f0a6 --- /dev/null +++ b/tests/neural/mma-helper-test-multi-warps-arbitrary-size.slang @@ -0,0 +1,245 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +import neural; +import common; +#pragma warning(disable: 41017) + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 
1023 +// + + +// Make the weight matrix a 64x64 matrix in row major order +// TEST_INPUT:ubuffer(stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(data=[0], stride=4, count=64):name=biasBuffer +RWStructuredBuffer biasBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=512):out, name=debugBuffer +RWStructuredBuffer debugBuffer; + +static const int BatchSize = 64; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[512]; + +groupshared bool s_resultMem[BatchSize/32]; + +typealias SPtr = Ptr; + +// Initialize the weight matrix as 0.1 +void fillWeightMatrix(uint tid) +{ + uint numIter = (OutputSize + BatchSize - 1) / BatchSize; + + for (int i = 0; i < numIter; i++) + { + int rowIndex = i * BatchSize + tid; + if (rowIndex >= OutputSize) + break; + + int startIndex = rowIndex * InputSize; + for (int j = 0; j < InputSize; j++) + { + int index = startIndex + j; + inputBuffer[index] = DType(0.125f); + } + } +} + +DType[MMAHelper.Uint4AlignedM] +testMatVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + // const int uint4AlignedInputSize = MMAHelper.Uint4AlignedK; + typealias MMA = MMAHelper; + const int InSize = MMA.Uint4AlignedK; + const int OutSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = ptrB; + +#if !TEST_BIAS + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); +#else + Storage biasStorage = Storage(biasBuffer); + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, + biasStorage, Optional(0)); +#endif + + // serialRead(tid, ptrC, 32); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. 
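+// It walks the uint4-packed shared memory row by row; for 16-bit element types each
+// 32-bit word holds two packed values, which is why the fp16 branch below unpacks the
+// low and high halves separately before printing.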
+void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + int numVectors = sizeof(uint4) / sizeof(T); + for (int i = 0; i < numVectors; i++) + { + uint4 values = sharedMem[id * numVectors + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + else + { + float aa = bit_cast(value); + printf("%.1f ", aa); + } + } + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + DType[MMAHelper.Uint4AlignedM] outputVector; + + __target_switch + { + case cuda: + outputVector = testMatVecMul(tid); + break; + case spirv: + outputVector = testMatVecMul(tid); + break; + } + + bool res = true; + + for (int i = 0; i < OutputSize; i++) + { + DType scaleFactor = DType(0.125f); +#if !TEST_BIAS + DType expected = DType((tid + 1) * InputSize) * DType(scaleFactor); +#else + DType expected = DType((tid + 1) * InputSize) * DType(scaleFactor) + DType(i); +#endif + if (!equals(outputVector[i], expected)) + { + // printf("tid: %d, i: %d, expected: %.4f, actual: %.4f\n", tid, i, float(expected), float(outputVector[i])); + res = false; + break; + } + } + + bool finalResult = collectResults(tid, res); + if (tid == 0) + { + outputBuffer[resIndex] = finalResult ? 1 : 0; + } +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initBias(tid, biasBuffer); + + { + fillWeightMatrix<3, 5>(tid); + AllMemoryBarrierWithGroupSync(); + + test<3, 5>(tid, 0); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<11, 18>(tid); + AllMemoryBarrierWithGroupSync(); + + test<11, 18>(tid, 1); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<23, 9>(tid); + AllMemoryBarrierWithGroupSync(); + + test<23, 9>(tid, 2); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<35, 3>(tid); + AllMemoryBarrierWithGroupSync(); + + test<35, 3>(tid, 3); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<17, 64>(tid); + AllMemoryBarrierWithGroupSync(); + + test<17, 64>(tid, 4); + } + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-multi-warps.slang b/tests/neural/mma-helper-test-multi-warps.slang new file mode 100644 index 0000000000..fee5e4debc --- /dev/null +++ b/tests/neural/mma-helper-test-multi-warps.slang @@ -0,0 +1,272 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + 
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +#pragma warning(disable: 41017) +import neural; +import common; + + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + + +// Make the weight matrix a 64x64 matrix in row major order +// TEST_INPUT:ubuffer(data=[0], stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(data=[0], stride=4, count=64):name=biasBuffer +RWStructuredBuffer biasBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=64):out, name=resultBuffer +//RWStructuredBuffer resultBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=1024):out, name=debugBuffer +// RWStructuredBuffer debugBuffer; + +static const int BatchSize = 64; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[512]; + +groupshared bool s_resultMem[BatchSize/32]; + +typealias SPtr = Ptr; + +// Initialize the weight matrix as identity matrix. 
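+// With W = I the mat-vec product returns the input unchanged, so each thread expects
+// outputVector[i] == tid + 1 (plus the bias term i when TEST_BIAS is enabled); see
+// verifyResult() below.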
+void identityWeightMatrix(uint tid) +{ + if (InputSize <= BatchSize) + { + if (tid > InputSize) + return; + + // InputSize mount of threads are enough to fill the input buffer + int index = tid * InputSize + tid; + if (Clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(1); + } + else + { + int numWrites = InputSize / BatchSize; + for (int i = 0; i < numWrites; i++) + { + int row = tid + i * BatchSize; + int index = row * InputSize + row; + if (Clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(1); + } + } +} + +DType[MMAHelper.Uint4AlignedM] testMatVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + typealias MMA = MMAHelper; + const int InSize = MMA.Uint4AlignedK; + const int OutSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = ptrB; + +#if !TEST_BIAS + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); +#else + Storage biasStorage = Storage(biasBuffer); + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, + biasStorage, Optional(0)); +#endif + + // serialRead(tid, ptrC, 32); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + int numVectors = sizeof(uint4) / sizeof(T); + for (int i = 0; i < numVectors; i++) + { + uint4 values = sharedMem[id * numVectors + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + else + { + float aa = bit_cast(value); + printf("%.1f ", aa); + } + } + } + printf("\n"); + } +} + +bool verifyResult>(uint tid, T output) +{ + bool res = true; + for (int i = 0; i < OutputSize; i++) + { +#if !TEST_BIAS + DType expected = DType(tid + 1); +#else + DType expected = DType(tid + 1) + DType(i); +#endif + if (!equals(output[i], expected)) + { + res = false; + break; + } + } + return res; +} + +void test(uint tid, int resIndex) +{ + DType[MMAHelper.Uint4AlignedM] outputVector; + + __target_switch + { + case cuda: + outputVector = testMatVecMul(tid); + break; + case spirv: + outputVector = testMatVecMul(tid); + break; + } + + bool res = verifyResult(tid, outputVector); + bool finalResult = collectResults(tid, res); + if (tid == 0) + { + outputBuffer[resIndex] = finalResult ? 
1 : 0; + } +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initBias(tid, biasBuffer); + + { + identityWeightMatrix<4, false>(tid); + AllMemoryBarrierWithGroupSync(); + + test<4, 4>(tid, 0); + + AllMemoryBarrierWithGroupSync(); + identityWeightMatrix<4, true>(tid); + AllMemoryBarrierWithGroupSync(); + } + + { + identityWeightMatrix<8, false>(tid); + AllMemoryBarrierWithGroupSync(); + + test<8, 8>(tid, 1); + + AllMemoryBarrierWithGroupSync(); + identityWeightMatrix<8, true>(tid); + AllMemoryBarrierWithGroupSync(); + } + + { + identityWeightMatrix<16, false>(tid); + AllMemoryBarrierWithGroupSync(); + + test<16, 16>(tid, 2); + + AllMemoryBarrierWithGroupSync(); + identityWeightMatrix<16, true>(tid); + AllMemoryBarrierWithGroupSync(); + } + + { + identityWeightMatrix<32, false>(tid); + AllMemoryBarrierWithGroupSync(); + + test<32, 32>(tid, 3); + AllMemoryBarrierWithGroupSync(); + + identityWeightMatrix<32, true>(tid); + AllMemoryBarrierWithGroupSync(); + } + + { + identityWeightMatrix<64, false>(tid); + AllMemoryBarrierWithGroupSync(); + + test<64, 64>(tid, 4); + } + + + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-single-warp-arbitrary-size.slang b/tests/neural/mma-helper-test-single-warp-arbitrary-size.slang new file mode 100644 index 0000000000..86f18d1053 --- /dev/null +++ b/tests/neural/mma-helper-test-single-warp-arbitrary-size.slang @@ -0,0 +1,219 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +#pragma warning(disable: 41017) + +import neural; +import common; + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + + +// Make the weight matrix a 64x64 matrix in row major order +// TEST_INPUT:ubuffer(data=[0], stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(data=[0], stride=4, 
count=64):name=biasBuffer +RWStructuredBuffer biasBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=256):out, name=debugBuffer +// RWStructuredBuffer debugBuffer; + +static const int BatchSize = 32; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[256]; + +typealias SPtr = Ptr; + +// Initialize the weight matrix as half(0.5). +void fillWeightMatrix(uint tid) +{ + uint numIter = (OutputSize + BatchSize - 1) / BatchSize; + + for (int i = 0; i < numIter; i++) + { + int rowIndex = i * BatchSize + tid; + if (rowIndex >= OutputSize) + break; + + int startIndex = rowIndex * InputSize; + for (int j = 0; j < InputSize; j++) + { + int index = startIndex + j; + inputBuffer[index] = DType(0.5); + } + } + + if (tid == 0) + { + for (int i = 0; i < 64; i++) + { + biasBuffer[i] = DType(i); + } + } +} + + +DType[MMAHelper.Uint4AlignedM] testMatVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + typealias MMA = MMAHelper; + const int InSize = MMA.Uint4AlignedK; + const int OutSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = __getAddress(s_sharedMemoryB[0]); + +#if !TEST_BIAS + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); +#else + Storage biasStorage = Storage(biasBuffer); + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, + biasStorage, Optional(0)); +#endif + // serialRead(tid, ptrC, 32); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + + half aa = bit_cast(a); + half bb = bit_cast(b); + // debugBuffer[linearIndex++] = aa; + // debugBuffer[linearIndex++] = bb; + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + DType[MMAHelper.Uint4AlignedM] outputVector; + __target_switch + { + case cuda: + outputVector = testMatVecMul(tid); + break; + case spirv: + outputVector = testMatVecMul(tid); + break; + } + + bool res = true; + if (tid < OutputSize) + { + for (int i = 0; i < OutputSize; i++) + { +#if !TEST_BIAS + DType expected = DType(tid + 1) * DType(InputSize) * DType(0.5); +#else + DType expected = DType(tid + 1) * DType(InputSize) * DType(0.5) + DType(i); +#endif + if (!equals(outputVector[i], expected)) + { + res = false; + break; + } + } + } + + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[resIndex] = res ? 
1 : 0; +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initBias(tid, biasBuffer); + + { + fillWeightMatrix<3, 5>(tid); + GroupMemoryBarrierWithWaveSync(); + test<3, 5>(tid, 0); + } + + { + fillWeightMatrix<11, 18>(tid); + GroupMemoryBarrierWithWaveSync(); + test<11, 18>(tid, 1); + } + + { + fillWeightMatrix<23, 9>(tid); + GroupMemoryBarrierWithWaveSync(); + test<23, 9>(tid, 2); + } + + { + fillWeightMatrix<35, 3>(tid); + GroupMemoryBarrierWithWaveSync(); + test<35, 3>(tid, 3); + } + + { + fillWeightMatrix<17, 64>(tid); + GroupMemoryBarrierWithWaveSync(); + test<17, 64>(tid, 4); + } + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-single-warp.slang b/tests/neural/mma-helper-test-single-warp.slang new file mode 100644 index 0000000000..f08bbeaa25 --- /dev/null +++ b/tests/neural/mma-helper-test-single-warp.slang @@ -0,0 +1,242 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=0 + +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 -xslang -DTEST_BIAS=1 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 -xslang -DTEST_BIAS=1 + +#pragma warning(disable: 41017) +import neural; +import common; + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + + +// TEST_INPUT:ubuffer(stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + + +// TEST_INPUT:ubuffer(stride=4, count=64):name=biasBuffer +RWStructuredBuffer biasBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=256):out, name=debugBuffer +// RWStructuredBuffer debugBuffer; + +static const int BatchSize = 32; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[256]; + +typealias SPtr = Ptr; + +// Initialize the weight matrix as identity matrix. 
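+// Each thread writes one diagonal entry per pass; when InputSize exceeds BatchSize (32 here),
+// a thread covers InputSize / BatchSize entries, e.g. rows tid and tid + 32 in the 64x64 case.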
+void identityWeightMatrix(uint tid, bool clear = false) +{ + if (InputSize <= BatchSize) + { + if (tid > InputSize) + return; + + // InputSize mount of threads are enough to fill the input buffer + int index = tid * InputSize + tid; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(1); + } + else + { + int numWrites = InputSize / BatchSize; + for (int i = 0; i < numWrites; i++) + { + int row = tid + i * BatchSize; + int index = row * InputSize + row; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(1); + } + } + +#if TEST_BIAS + if (tid == 0) + { + // We always fill the bias buffer with the index of the thread. + for (int i = 0; i < 64; i++) + { + biasBuffer[i] = DType(i); + } + } +#endif +} + +DType[MMAHelper.Uint4AlignedM] testMatVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + + + // Construct the input vector as follow: + // x = tid + 1 + typealias MMA = MMAHelper; + const int InSize = MMA.Uint4AlignedK; + const int OutSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + + // shared memory B and C can be reused. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = __getAddress(s_sharedMemoryB[0]); + +#if !TEST_BIAS + let res = MMA.mma( inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); +#else + Storage biasStorage = Storage(biasBuffer); + let res = MMA.mma( inputVector, ptrA, ptrB, ptrC, storage, 0, biasStorage, Optional(0)); +#endif + + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + + half aa = bit_cast(a); + half bb = bit_cast(b); + // debugBuffer[linearIndex++] = aa; + // debugBuffer[linearIndex++] = bb; + printf("%.1f %.1f ", float(aa), float(bb)); + } + else + { + printf("%x ", value); + } + } + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + DType[MMAHelper.Uint4AlignedM] outputVector; + __target_switch + { + case cuda: + outputVector = testMatVecMul(tid); + break; + case spirv: + outputVector = testMatVecMul(tid); + break; + } + + bool res = true; + for (int i = 0; i < OutputSize; i++) + { +#if !TEST_BIAS + DType expected = DType(tid + 1); +#else + DType expected = DType(tid + 1) + DType(i); +#endif + if (!equals(outputVector[i], expected)) + { + res = false; + break; + } + } + + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[resIndex] = res ? 
1 : 0; +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initBias(tid, biasBuffer); + + { + identityWeightMatrix<4>(tid); + GroupMemoryBarrierWithWaveSync(); + test<4, 4>(tid, 0); + identityWeightMatrix<4>(tid, true); + } + + { + identityWeightMatrix<8>(tid); + GroupMemoryBarrierWithWaveSync(); + test<8, 8>(tid, 1); + identityWeightMatrix<8>(tid, true); + } + + { + identityWeightMatrix<16>(tid); + GroupMemoryBarrierWithWaveSync(); + test<16, 16>(tid, 2); + identityWeightMatrix<16>(tid, true); + } + + { + identityWeightMatrix<32>(tid); + GroupMemoryBarrierWithWaveSync(); + test<32, 32>(tid, 3); + identityWeightMatrix<32>(tid, true); + } + + { + identityWeightMatrix<64>(tid); + GroupMemoryBarrierWithWaveSync(); + test<64, 64>(tid, 4); + identityWeightMatrix<64>(tid, true); + } + + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-transpose-multi-warps-arbitrary-size.slang b/tests/neural/mma-helper-test-transpose-multi-warps-arbitrary-size.slang new file mode 100644 index 0000000000..c50eeab06f --- /dev/null +++ b/tests/neural/mma-helper-test-transpose-multi-warps-arbitrary-size.slang @@ -0,0 +1,265 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 + +#pragma warning(disable: 41017) +import neural; +import common; + + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + + +// Make the weight matrix a 64x64 matrix in row major order +// TEST_INPUT:ubuffer(stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=256):out, name=debugBuffer +// RWStructuredBuffer debugBuffer; + +static const int BatchSize = 64; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[256]; + +typealias SPtr = Ptr; + +// Initialize the weight with the pattern: +// each column is filled with -1 0 1 * column_index in a cycle. 
+// Initialize the weight with the pattern: +// each column is filled with -1 0 1 * column_index in a cycle. +void fillWeightMatrix(uint tid, bool clear = false) +{ + const int ROW = OutputSize; + const int COL = InputSize; + + int scaleFactor = 0; + if (ROW <= BatchSize) + { + if (tid > ROW) + return; + + // InputSize mount of threads are enough to fill the input buffer + // int index = tid * InputSize + tid; + int row = tid; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int i = 0; i < COL; i++) + { + int index = row * COL + i; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (i + 1)); + } + } + else + { + int numWrites = ROW / BatchSize; + for (int i = 0; i < numWrites; i++) + { + int row = tid + i * BatchSize; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int j = 0; j < COL; j++) + { + int index = row * COL + j; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (j + 1)); + } + } + } +} + +DType[MMAHelper.Uint4AlignedK] testMatTransposeVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + // const int uint4AlignedInputSize = MMAHelper.Uint4AlignedK; + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedK; + const int InSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + + for (int i = 0; i < OutputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = __getAddress(s_sharedMemoryB[0]); + + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); + // serialRead(tid, ptrB, 32); + // serialReadA(tid); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + + half aa = bit_cast(a); + half bb = bit_cast(b); + // debugBuffer[linearIndex++] = aa; + // debugBuffer[linearIndex++] = bb; + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialReadA(uint tid) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int i = 0; i < OutputSize; i++) + { + printf("row: %d\n", i); + for (int j = 0; j < InputSize; j++) + { + printf("%.1f ", float(inputBuffer[i * InputSize + j])); + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + const int COL = InputSize; + const int ROW = OutputSize; + + DType[MMAHelper.Uint4AlignedK] outputVector; + __target_switch + { + case cuda: + outputVector = testMatTransposeVecMul(tid); + break; + case spirv: + outputVector = testMatTransposeVecMul(tid); + break; + } + + bool res = true; + int scaleFactor = (OutputSize % 3 == 0) ? 
0 : -1; + for (int i = 0; i < COL; i++) + { + DType expected = DType(scaleFactor) * DType((i + 1) * (tid + 1)); + if (!equals(outputVector[i], expected)) + { + // printf( "tid: %d, expected: %.1f, actual: %.1f\n", tid, expected, float(outputVector[i])); + res = false; + break; + } + } + + bool finalResult = collectResults(tid, res); + + // Finally, use the first thread to write the result to the output buffer. + if (tid == 0) + { + outputBuffer[resIndex] = finalResult ? 1 : 0; + } +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + fillWeightMatrix<3, 5>(tid); + GroupMemoryBarrierWithGroupSync(); + test<3, 5>(tid, 0); + GroupMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<11, 18>(tid); + GroupMemoryBarrierWithGroupSync(); + test<11, 18>(tid, 1); + GroupMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<23, 9>(tid); + GroupMemoryBarrierWithGroupSync(); + test<23, 9>(tid, 2); + GroupMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<35, 3>(tid); + GroupMemoryBarrierWithGroupSync(); + test<35, 3>(tid, 3); + GroupMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<17, 64>(tid); + GroupMemoryBarrierWithGroupSync(); + test<17, 64>(tid, 4); + GroupMemoryBarrierWithGroupSync(); + } + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-transpose-multi-warps.slang b/tests/neural/mma-helper-test-transpose-multi-warps.slang new file mode 100644 index 0000000000..36762b35ef --- /dev/null +++ b/tests/neural/mma-helper-test-transpose-multi-warps.slang @@ -0,0 +1,226 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 + +#pragma warning(disable: 41017) +import neural; +import common; + + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + + +// Make the weight matrix a 64x64 matrix in row major order +// TEST_INPUT:ubuffer(stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=4096):out, name=debugBuffer +RWStructuredBuffer debugBuffer; + +static const int BatchSize = 64; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. 
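+// Each array below holds 256 uint4 values (4 KB); the exact tile footprint is determined by
+// MMAHelper, so these sizes are deliberately generous for the test rather than tuned.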
+groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[256]; + +groupshared bool s_resultMem[BatchSize/32]; + +typealias SPtr = Ptr; + +void fillWeightMatrix(uint tid, bool clear = false) +{ + const int ROW = OutputSize; + const int COL = InputSize; + + int scaleFactor = 0; + if (ROW <= BatchSize) + { + if (tid > ROW) + return; + + // InputSize mount of threads are enough to fill the input buffer + // int index = tid * InputSize + tid; + int row = tid; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int i = 0; i < COL; i++) + { + int index = row * COL + i; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (i + 1)); + } + } + else + { + int numWrites = ROW / BatchSize; + for (int i = 0; i < numWrites; i++) + { + int row = tid + i * BatchSize; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int j = 0; j < COL; j++) + { + int index = row * COL + j; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (j + 1)); + } + } + } +} + +DType[MMAHelper.Uint4AlignedK] testMatTransposeVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedK; + const int InSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + + for (int i = 0; i < OutputSize; i++) + { + inputVector[i] = DType(tid + 1) * DType(0.125f); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = __getAddress(s_sharedMemoryB[0]); + + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); + // serialRead(tid, ptrB, 16); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + + half aa = bit_cast(a); + half bb = bit_cast(b); + // debugBuffer[linearIndex++] = aa; + // debugBuffer[linearIndex++] = bb; + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + const int COL = InputSize; + const int ROW = OutputSize; + + DType[MMAHelper.Uint4AlignedK] outputVector; + __target_switch + { + case cuda: + outputVector = testMatTransposeVecMul(tid); + break; + case spirv: + outputVector = testMatTransposeVecMul(tid); + break; + } + + bool res = true; + int scaleFactor = (OutputSize % 3 == 0) ? 0 : -1; + for (int i = 0; i < COL; i++) + { + int intVal = (i + 1) * (tid + 1); + DType expected = DType(scaleFactor) * DType(intVal) * DType(0.125f); + if (!equals(outputVector[i], expected)) + { + res = false; + break; + } + } + + bool finalResult = collectResults(tid, res); + + // Finally, use the first thread to write the result to the output buffer. 
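+    // collectResults() has already reduced the per-thread checks across all warps in the
+    // workgroup, so a single write of finalResult per test case is sufficient.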
+ if (tid == 0) + { + outputBuffer[resIndex] = finalResult ? 1 : 0; + } +} + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + fillWeightMatrix<4, 4>(tid); + AllMemoryBarrierWithGroupSync(); + test<4, 4>(tid, 0); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<8, 8>(tid); + AllMemoryBarrierWithGroupSync(); + test<8, 8>(tid, 1); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<16, 16>(tid); + AllMemoryBarrierWithGroupSync(); + test<16, 16>(tid, 2); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<32, 32>(tid); + AllMemoryBarrierWithGroupSync(); + test<32, 32>(tid, 3); + AllMemoryBarrierWithGroupSync(); + } + + { + fillWeightMatrix<64, 64>(tid); + AllMemoryBarrierWithGroupSync(); + test<64, 64>(tid, 4); + } + + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-transpose-single-warp-arbitrary-size.slang b/tests/neural/mma-helper-test-transpose-single-warp-arbitrary-size.slang new file mode 100644 index 0000000000..837e33d08f --- /dev/null +++ b/tests/neural/mma-helper-test-transpose-single-warp-arbitrary-size.slang @@ -0,0 +1,252 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 + +#pragma warning(disable: 41017) +import neural; +import common; + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + + +// Make the weight matrix a 64x64 matrix in row major order +// TEST_INPUT:ubuffer(data=[0], stride=2, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=256):out, name=debugBuffer +// RWStructuredBuffer debugBuffer; + +static const int BatchSize = 32; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[256]; + +typealias SPtr = Ptr; + +// Initialize the weight with the pattern: +// each column is filled with -1 0 1 * column_index in a cycle. 
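+// Concretely (see fillWeightMatrix below): rows with row % 3 == 0 hold -(col + 1), rows with
+// row % 3 == 1 hold 0, and rows with row % 3 == 2 hold +(col + 1).
+// For the transpose product y = W^T * x with x[r] = tid + 1, column j therefore accumulates
+// (j + 1) * (tid + 1) * sum_r scale(r); that row-scale sum is 0 when OutputSize % 3 == 0 and
+// -1 otherwise, which is exactly the scaleFactor the verification loop in test() uses.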
+void fillWeightMatrix(uint tid, bool clear = false) +{ + const int ROW = OutputSize; + const int COL = InputSize; + + int scaleFactor = 0; + if (ROW <= BatchSize) + { + if (tid > ROW) + return; + + // InputSize mount of threads are enough to fill the input buffer + // int index = tid * InputSize + tid; + int row = tid; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int i = 0; i < COL; i++) + { + int index = row * COL + i; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (i + 1)); + } + } + else + { + int numWrites = ROW / BatchSize; + for (int i = 0; i < numWrites; i++) + { + int row = tid + i * BatchSize; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int j = 0; j < COL; j++) + { + int index = row * COL + j; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (j + 1)); + } + } + } +} + +DType[MMAHelper.Uint4AlignedK] testMatTransposeVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedK; + const int InSize = MMA.Uint4AlignedM; + DType inputVector[InSize] = {}; + + // Since this is transpose test, OutputSize is the actual input size. + for (int i = 0; i < OutputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = __getAddress(s_sharedMemoryB[0]); + + let res = MMA.mma(inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); + // serialReadA(tid); + // serialRead(tid, ptrC, 32); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + + half aa = bit_cast(a); + half bb = bit_cast(b); + // debugBuffer[linearIndex++] = aa; + // debugBuffer[linearIndex++] = bb; + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialReadA(uint tid) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int i = 0; i < OutputSize; i++) + { + printf("row: %d\n", i); + for (int j = 0; j < InputSize; j++) + { + printf("%.1f ", float(inputBuffer[i * InputSize + j])); + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + const int COL = InputSize; + const int ROW = OutputSize; + + DType[MMAHelper.Uint4AlignedK] outputVector; + __target_switch + { + case cuda: + outputVector = testMatTransposeVecMul(tid); + break; + case spirv: + outputVector = testMatTransposeVecMul(tid); + break; + } + + bool res = true; + int scaleFactor = (OutputSize % 3 == 0) ? 
0 : -1; + for (int i = 0; i < COL; i++) + { + DType expected = DType(scaleFactor * (i + 1.0f) * (tid + 1.0f)); + if (!equals(outputVector[i], expected)) + { + res = false; + break; + } + } + + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + fillWeightMatrix<3, 5>(tid); + GroupMemoryBarrierWithWaveSync(); + test<3, 5>(tid, 0); + } + + { + fillWeightMatrix<11, 18>(tid); + GroupMemoryBarrierWithWaveSync(); + test<11, 18>(tid, 1); + } + + { + fillWeightMatrix<23, 9>(tid); + GroupMemoryBarrierWithWaveSync(); + test<23, 9>(tid, 2); + } + + { + fillWeightMatrix<35, 3>(tid); + GroupMemoryBarrierWithWaveSync(); + test<35, 3>(tid, 3); + } + + { + fillWeightMatrix<17, 64>(tid); + GroupMemoryBarrierWithWaveSync(); + test<17, 64>(tid, 4); + } + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/mma-helper-test-transpose-single-warp.slang b/tests/neural/mma-helper-test-transpose-single-warp.slang new file mode 100644 index 0000000000..4358f59cf8 --- /dev/null +++ b/tests/neural/mma-helper-test-transpose-single-warp.slang @@ -0,0 +1,253 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 + +#pragma warning(disable: 41017) +import neural; +import common; + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + + +// Make the weight matrix a 64x64 matrix in row major order + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// TEST_INPUT:ubuffer(stride=4, count=4096):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=5):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +// // TEST_INPUT:ubuffer(stride=4, count=256):out, name=debugBuffer +// RWStructuredBuffer debugBuffer; + +static const int BatchSize = 32; +static const int SubgroupSize = 32; + +// Allocate large enough shared memory in this test. +groupshared uint4 s_sharedMemoryA[256]; +groupshared uint4 s_sharedMemoryB[256]; + +typealias SPtr = Ptr; +typealias HalfPtr = Ptr; + +// Initialize the weight with the pattern: +// each column is filled with -1 0 1 * column_index in a cycle. 
+void fillWeightMatrix(uint tid, bool clear = false) +{ + // int scaleFactor = (tid % 3 == 0) ? -1 : ((tid % 3 == 1) ? 0 : 1); + int scaleFactor = 0; + if (OutputSize <= BatchSize) + { + if (tid > OutputSize) + return; + + // InputSize mount of threads are enough to fill the input buffer + // int index = tid * InputSize + tid; + int row = tid; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int i = 0; i < InputSize; i++) + { + int index = row * InputSize + i; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (i + 1)); + } + } + else + { + int numWrites = InputSize / BatchSize; + for (int i = 0; i < numWrites; i++) + { + int row = tid + i * BatchSize; + int scaleFactor = (row % 3 == 0) ? -1 : ((row % 3 == 1) ? 0 : 1); + for (int j = 0; j < InputSize; j++) + { + int index = row * InputSize + j; + if (clear) + inputBuffer[index] = DType(0); + else + inputBuffer[index] = DType(scaleFactor * (j + 1)); + } + } + } +} + +DType[MMAHelper.Uint4AlignedK] testMatTransposeVecMul(uint tid) +{ + typealias Storage = StructuredBufferStorage; + Storage storage = Storage(inputBuffer); + + // Construct the input vector as follow: + // x = tid + 1 + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedK; + const int InSize = MMA.Uint4AlignedM; + + DType inputVector[InSize] = {}; + + for (int i = 0; i < OutputSize; i++) + { + inputVector[i] = DType(tid + 1); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + // In this test, we can safely share the shMemB and shMemC, because they are always used at different time. + // And they are always the same size. + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + SPtr ptrC = __getAddress(s_sharedMemoryB[0]); + + let res = MMA.mma( inputVector, ptrA, ptrB, ptrC, storage, 0, none, none); + return res; +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem, uint rowOrColumnCount) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + int linearIndex = 0; + for (int id = 0; id < rowOrColumnCount; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 8; i++) + { + uint4 values = sharedMem[id * 8 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + + half aa = bit_cast(a); + half bb = bit_cast(b); + // debugBuffer[linearIndex++] = aa; + // debugBuffer[linearIndex++] = bb; + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialReadA(uint tid) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int i = 0; i < 64; i++) + { + printf("row: %d\n", i); + for (int j = 0; j < 64; j++) + { + printf("%.1f ", float(inputBuffer[i * 64 + j])); + } + printf("\n"); + } +} + +void test(uint tid, int resIndex) +{ + DType[MMAHelper.Uint4AlignedK] outputVector; + __target_switch + { + case cuda: + outputVector = testMatTransposeVecMul(tid); + break; + case spirv: + outputVector = testMatTransposeVecMul(tid); + break; + } + + bool res = true; + int scaleFactor = (OutputSize % 3 == 0 ) ? 
0 : -1; + for (int i = 0; i < InputSize; i++) + { + DType expected = DType(scaleFactor * (i + 1.0f) * (tid + 1.0f)); + if (!equals(outputVector[i], expected)) + { + res = false; + break; + } + } + + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + + +[numthreads(BatchSize, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + fillWeightMatrix<4, 4>(tid); + GroupMemoryBarrierWithWaveSync(); + test<4, 4>(tid, 0); + fillWeightMatrix<4, 4>(tid, true); + } + + { + fillWeightMatrix<8, 8>(tid); + GroupMemoryBarrierWithWaveSync(); + test<8, 8>(tid, 1); + fillWeightMatrix<8, 8>(tid, true); + } + + { + fillWeightMatrix<15, 15>(tid); + GroupMemoryBarrierWithWaveSync(); + test<15, 15>(tid, 2); + fillWeightMatrix<15, 15>(tid, true); + } + + { + fillWeightMatrix<32, 32>(tid); + GroupMemoryBarrierWithWaveSync(); + test<32, 32>(tid, 3); + fillWeightMatrix<32, 32>(tid, true); + } + + { + fillWeightMatrix<64, 64>(tid); + GroupMemoryBarrierWithWaveSync(); + test<64, 64>(tid, 4); + fillWeightMatrix<64, 64>(tid, true); + } + + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/outerproduct-accumulate-test-arbitrary-size.slang b/tests/neural/outerproduct-accumulate-test-arbitrary-size.slang new file mode 100644 index 0000000000..d545a219b1 --- /dev/null +++ b/tests/neural/outerproduct-accumulate-test-arbitrary-size.slang @@ -0,0 +1,220 @@ + +// On Vulkan, we can only test float type for now because atomicAdd is not supported for half on our CI machine. +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=0 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=1 + +#pragma warning(disable: 41017) +import neural; +import common; + +// TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4, count=4):out, name=resultBuffer +RWStructuredBuffer resultBuffer; + + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// TEST_INPUT: ubuffer(stride=4, count=4096):name=outputBuffer +RWStructuredBuffer outputBuffer; + +typealias BufferStorage = StructuredBufferStorage; + +static const int SubgroupSize = 32; +static const int WorkgroupSize = 128; +static const int workgroupCount = WorkgroupSize / SubgroupSize; + +static const int MaxK = 64; +static const int MaxM = 64; + +static const int ShMemSizeA = MaxM * 2; +static const int ShMemSizeB = MaxK * 2; + +groupshared uint4 s_sharedMemoryA[ShMemSizeA * workgroupCount]; +groupshared uint4 s_sharedMemoryB[ShMemSizeB * workgroupCount * (sizeof(DType) / sizeof(half))]; + +typealias SPtr = Ptr; + +// Basic test on MatMul without bias, this test covers both forward and backward pass +void TestOuterProductAccumulate(uint tid) +{ + BufferStorage storage = BufferStorage(outputBuffer); + + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedM; + const int InSize = MMA.Uint4AlignedK; + + DType dOutVector[OutSize] = {}; + DType inputVector[InSize] = {}; + + float scaleFactor = 1.0f / WorkgroupSize; + + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = DType((i + 1) * (tid + 1) * scaleFactor); + } + + for (int i = 0; i < 
OutputSize; i++) + { + dOutVector[i] = DType((OutputSize - i) * scaleFactor); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + + // perform dOut OPA Input + MMA.outerProductAccumulate(ptrA, ptrB, dOutVector, inputVector, storage, 0); + + // serialRead<48>(tid, ptrB); + // serialRead(tid, outputBuffer); +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int id = 0; id < 16; id++) + { + printf("tid: %d\n", id); + int strideInVector = Stride / (sizeof(uint4) / sizeof(half)); + for (int i = 0; i < strideInVector; i++) + { + uint4 values = sharedMem[id * strideInVector + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + else + { + printf("%f ", bit_cast(value)); + } + } + } + printf("\n"); + } +} + +void serialRead(uint tid, RWStructuredBuffer outputBuffer) +{ + GroupMemoryBarrierWithWaveSync(); + + if (tid > 0) + return; + + for (int i = 0; i < Row; i++) + { + printf("row: %d\n", i); + for (int j = 0; j < Column; j++) + { + DType value = outputBuffer[i * Column + j]; + printf("%.1f ", value); + } + printf("\n"); + } +} + +void test(uint tid, uint resIndex) +{ + __target_switch + { + case cuda: + TestOuterProductAccumulate(tid); + break; + case spirv: + TestOuterProductAccumulate(tid); + break; + } + + int subgroupIndex = tid / SubgroupSize; + if (subgroupIndex != 0) + return; + + // We just use the first warp to check the result to simplicity + bool res = true; + int numIterPerThread = (OutputSize + SubgroupSize - 1) / SubgroupSize; + float scaleFactor = 1.0f / WorkgroupSize; + scaleFactor *= scaleFactor; + scaleFactor = scaleFactor * ((1 + WorkgroupSize) * WorkgroupSize/2.0f); + + // Note that half type is not precise, and more accumulate will cause more error, so + // we increase the error threshold by using the scale of the workgroup size which is exactly + // the number of accumulations. + // This is a purely empirical value, good input pattern could make the error smaller. + DType errorThreshold = DType(WorkgroupSize * 0.016); + + for (int i = 0; i < numIterPerThread; i++) + { + int rowIdx = i * SubgroupSize + tid; + if (rowIdx >= OutputSize) + break; + + int startVal = OutputSize - rowIdx; + + // each thread will check one row + for (int j = 0; j < InputSize; j++) + { + DType expected = DType(startVal * scaleFactor * (j + 1)); + DType actual = outputBuffer[rowIdx * InputSize + j]; + if (abs(expected - actual) > errorThreshold) + { + res = false; + printf("tid: %d, rowIdx: %d, j: %d, expected: %.4f, actual: %.4f\n", tid, rowIdx, j, float(expected), float(actual)); + break; + } + } + } + + res = WaveActiveAllTrue(res); + if (tid == 0) + resultBuffer[resIndex] = res ? 
1 : 0; +} + +[shader("compute")] +[numthreads(WorkgroupSize, 1, 1)] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + setBufferOneValue<3 * 15, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<3, 15>(tid, 0); + AllMemoryBarrierWithGroupSync(); + } + // BUFFER: 1 + + { + setBufferOneValue<20 * 30, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<20, 30>(tid, 1); + AllMemoryBarrierWithGroupSync(); + } + // BUFFER-NEXT: 1 + + { + setBufferOneValue<49 * 17, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<49, 17>(tid, 2); + AllMemoryBarrierWithGroupSync(); + } + // BUFFER-NEXT: 1 + + { + setBufferOneValue<33 * 9, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<33, 9>(tid, 3); + AllMemoryBarrierWithGroupSync(); + } + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/outerproduct-accumulate-test.slang b/tests/neural/outerproduct-accumulate-test.slang new file mode 100644 index 0000000000..6101529150 --- /dev/null +++ b/tests/neural/outerproduct-accumulate-test.slang @@ -0,0 +1,219 @@ + +// On Vulkan, we can only test float type for now because atomicAdd is not supported for half on our CI machine. +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 + +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -xslang -experimental-feature -xslang -DTEST_HALF=1 + +#pragma warning(disable: 41017) +import neural; +import common; + +// TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4, count=4):out, name=resultBuffer +RWStructuredBuffer resultBuffer; + + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// TEST_INPUT: ubuffer(stride=4, count=4096):name=outputBuffer +RWStructuredBuffer outputBuffer; + +typealias BufferStorage = StructuredBufferStorage; + +static const int SubgroupSize = 32; +static const int WorkgroupSize = 128; +static const int workgroupCount = WorkgroupSize / SubgroupSize; + +static const int MaxK = 64; +static const int MaxM = 64; + +static const int ShMemSizeA = MaxM * 2; +static const int ShMemSizeB = MaxK * 2; + +groupshared uint4 s_sharedMemoryA[ShMemSizeA * workgroupCount]; +groupshared uint4 s_sharedMemoryB[ShMemSizeB * workgroupCount * sizeof(DType)/sizeof(half)]; + +typealias SPtr = Ptr; + +// Basic test on MatMul without bias, this test covers both forward and backward pass +void TestOuterProductAccumulate(uint tid) +{ + BufferStorage storage = BufferStorage(outputBuffer); + + typealias MMA = MMAHelper; + const int OutSize = MMA.Uint4AlignedM; + const int InSize = MMA.Uint4AlignedK; + + DType dOutVector[OutSize] = {}; + DType inputVector[InSize] = {}; + + float scaleFactor = 1.0f / WorkgroupSize; + + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = DType((i + 1) * (tid + 1) * scaleFactor); + } + + for (int i = 0; i < OutputSize; i++) + { + dOutVector[i] = DType((OutputSize - i) * scaleFactor); + } + + SPtr ptrA = __getAddress(s_sharedMemoryA[0]); + SPtr ptrB = __getAddress(s_sharedMemoryB[0]); + + // perform dOut OPA Input + MMA.outerProductAccumulate(ptrA, ptrB, dOutVector, inputVector, storage, 0); + + // serialRead<32, DType>(tid, ptrB); + // serialRead(tid, outputBuffer); +} + +// This function is just used for debugging, not for verification. So keep it here. 
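+// Aside (reviewer note on the values verified in test() further below, assuming
+// outerProductAccumulate atomically accumulates each thread's dOutVector (x) inputVector
+// outer product into the storage buffer): with W = WorkgroupSize,
+//
+//   output[r][c] = sum over t in [0, W) of (OutputSize - r)/W * (c + 1) * (t + 1)/W
+//               = (OutputSize - r) * (c + 1) / W^2 * W * (W + 1) / 2
+//
+// which matches the scaleFactor computed in test(): (1/W)^2 * (1 + W) * W / 2.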
+void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + for (int id = 0; id < 16; id++) + { + printf("tid: %d\n", id); + int strideInVector = Stride / (sizeof(uint4) / sizeof(T)); + for (int i = 0; i < strideInVector; i++) + { + uint4 values = sharedMem[id * strideInVector + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + if (sizeof(T) == 2) + { + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + else + { + printf("%f ", bit_cast(value)); + } + } + } + printf("\n"); + } +} + +void serialRead(uint tid, RWStructuredBuffer outputBuffer) +{ + GroupMemoryBarrierWithWaveSync(); + + if (tid > 0) + return; + + for (int i = 0; i < Row; i++) + { + for (int j = 0; j < Column; j++) + { + DType value = outputBuffer[i * Column + j]; + printf("%.1f ", float(value)); + } + printf("\n"); + } +} + +void test(uint tid, uint resIndex) +{ + __target_switch + { + case cuda: + TestOuterProductAccumulate(tid); + break; + case spirv: + TestOuterProductAccumulate(tid); + break; + } + + int subgroupIndex = tid / SubgroupSize; + if (subgroupIndex != 0) + return; + + // We just use the first warp to check the result to simplicity + bool res = true; + int numIterPerThread = (OutputSize + SubgroupSize - 1) / SubgroupSize; + float scaleFactor = 1.0f / WorkgroupSize; + scaleFactor *= scaleFactor; + scaleFactor = scaleFactor * ((1 + WorkgroupSize) * WorkgroupSize/2.0f); + + // Note that half type is not precise, and more accumulate will cause more error, so + // we increase the error threshold by using the scale of the workgroup size which is exactly + // the number of accumulations. + // This is a purely empirical value, good input pattern could make the error smaller. + DType errorThreshold = DType(WorkgroupSize * 0.016); + + for (int i = 0; i < numIterPerThread; i++) + { + int rowIdx = i * SubgroupSize + tid; + if (rowIdx >= OutputSize) + break; + + int startVal = OutputSize - rowIdx; + + // each thread will check one row + for (int j = 0; j < InputSize; j++) + { + DType expected = DType(startVal * scaleFactor * (j + 1)); + DType actual = outputBuffer[rowIdx * InputSize + j]; + if (abs(expected - actual) > errorThreshold) + { + res = false; + // printf("tid: %d, rowIdx: %d, j: %d, expected: %.4f, actual: %.4f\n", tid, rowIdx, j, float(expected), float(actual)); + break; + } + } + } + + res = WaveActiveAllTrue(res); + if (tid == 0) + resultBuffer[resIndex] = res ? 
1 : 0; +} + +[shader("compute")] +[numthreads(WorkgroupSize, 1, 1)] +void computeMain(uint tid : SV_DispatchThreadID) +{ + { + setBufferOneValue<8 * 8, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<8, 8>(tid, 0); + AllMemoryBarrierWithGroupSync(); + } + + { + setBufferOneValue<16 * 16, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<16, 16>(tid, 1); + AllMemoryBarrierWithGroupSync(); + } + + { + setBufferOneValue<32 * 32, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<32, 32>(tid, 2); + AllMemoryBarrierWithGroupSync(); + } + + { + setBufferOneValue<64 * 64, WorkgroupSize>(tid, outputBuffer, DType(0.0f)); + test<64, 64>(tid, 3); + AllMemoryBarrierWithGroupSync(); + } + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/shared-memory-size.slang b/tests/neural/shared-memory-size.slang new file mode 100644 index 0000000000..e8766f255e --- /dev/null +++ b/tests/neural/shared-memory-size.slang @@ -0,0 +1,235 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=0 +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTEST_HALF=1 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=0 +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTEST_HALF=1 + +// Test that given the layer configuration, thread count, data type, subgroup size, and execution mode, the shared memory required is correctly +// calculated during compile time. + +// SharedMemorySize.OfHiddenN.Bytes +// is the compile time (or link time) constant that represents the total shared memory needed for launcing the inference kernel or training kernel. + +import neural; +#pragma warning(disable: 41017) + +#if TEST_HALF +typealias DType = half; +#else +typealias DType = float; +#endif + +// TEST_INPUT:ubuffer(stride=4, count=16):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + +static const int BatchSize = 64; +static const int SubgroupSize = 32; + +// Compute the shared memory size to verify the output. +// This verify function uses naive implementation to calculate the shared memory size at run time. +// It just calculate the shared memory size for a single layer, and store the result in the result array. +// Then we will find the maximum value in the result array. +void computeOneLayer>(T inputSize, U outputSize, inout uint index, out Arr result) + where T == uint + where U == uint +{ + // In this test, the CoopMat shape is always 16x16x16. + const int coopMatShape = 16; + const int alignedK = ((inputSize + coopMatShape - 1) / coopMatShape) * coopMatShape; + const int alignedM = ((outputSize + coopMatShape - 1) / coopMatShape) * coopMatShape; + const int aliggnedN = SubgroupSize; + const int subgroupCount = BatchSize / SubgroupSize; + + const int tileASize = Mode == ExecutionMode.Inference ? 
+ alignedM * coopMatShape * sizeof(half) : // Forward, A is M x K, tile A is M x 16 + max(alignedM, alignedK) * coopMatShape * sizeof(half); // Backward, A is K x M, tile A is K x 16 + + // in backward, B is M x N, and N x K, tile B is 16 x max(N, K) + const int tileBSize = Mode == ExecutionMode.Inference ? + coopMatShape * aliggnedN * sizeof(half) : // forward, B is K x N, tile B is 16 x N + coopMatShape * (max(aliggnedN, alignedK)) * sizeof(half); // backward, B is M x N, tile B is 16 x max(N, K) + + const int tileCSize = tileBSize * sizeof(DType) / sizeof(half); + + const int maxTileBCSize = tileBSize > tileCSize ? tileBSize : tileCSize; + const int expectedSize = tileASize + maxTileBCSize * subgroupCount; + + result[index++] = expectedSize; +} + + +uint verifySize(expand each T a, expand each U b) + where T == uint + where U == uint +{ + uint index = 0; + uint resBuffer[countof(T)]; + expand computeOneLayer(each a, each b, index, resBuffer); + + int maxSize = 0; + + [ForceUnroll] + for (int i = 0; i < countof(T); i++) + { + maxSize = max(maxSize, resBuffer[i]); + } + + return maxSize; +} + +// randomly generated layer configurations + +// hidden layer count = 0 +#define TEST_CASE_1 5U, 9U +#define TEST_CASE_1_PAIR_1 5U +#define TEST_CASE_1_PAIR_2 9U + +// hidden layer count = 1 +#define TEST_CASE_2 17U, 23U, 3U +#define TEST_CASE_2_PAIR_1 17U, 23U +#define TEST_CASE_2_PAIR_2 23U, 3U + +// hidden layer count = 2 +#define TEST_CASE_3 1U, 2U, 3U, 4U +#define TEST_CASE_3_PAIR_1 1U, 2U, 3U +#define TEST_CASE_3_PAIR_2 2U, 3U, 4U + +// hidden layer count = 3 +#define TEST_CASE_4 7U, 18U, 29U, 42U, 57U +#define TEST_CASE_4_PAIR_1 7U, 18U, 29U, 42U +#define TEST_CASE_4_PAIR_2 18U, 29U, 42U, 57U + +// hidden layer count = 4 +#define TEST_CASE_5 3U, 11U, 24U, 37U, 49U, 62U +#define TEST_CASE_5_PAIR_1 3U, 11U, 24U, 37U, 49U +#define TEST_CASE_5_PAIR_2 11U, 24U, 37U, 49U, 62U + +// hidden layer count = 5 +#define TEST_CASE_6 38U, 7U, 59U, 14U, 46U, 3U, 27U +#define TEST_CASE_6_PAIR_1 38U, 7U, 59U, 14U, 46U, 3U +#define TEST_CASE_6_PAIR_2 7U, 59U, 14U, 46U, 3U, 27U + +// hidden layer count = 6 +#define TEST_CASE_7 41U, 6U, 58U, 13U, 29U, 2U, 47U, 35U +#define TEST_CASE_7_PAIR_1 41U, 6U, 58U, 13U, 29U, 2U, 47U +#define TEST_CASE_7_PAIR_2 6U, 58U, 13U, 29U, 2U, 47U, 35U + +// hidden layer count = 7 +#define TEST_CASE_8 3U, 21U, 45U, 12U, 56U, 7U, 39U, 50U, 8U +#define TEST_CASE_8_PAIR_1 3U, 21U, 45U, 12U, 56U, 7U, 39U, 50U +#define TEST_CASE_8_PAIR_2 21U, 45U, 12U, 56U, 7U, 39U, 50U, 8U + +// hidden layer count = 8 +#define TEST_CASE_9 11U, 27U, 4U, 33U, 58U, 19U, 42U, 6U, 31U, 22U +#define TEST_CASE_9_PAIR_1 11U, 27U, 4U, 33U, 58U, 19U, 42U, 6U, 31U +#define TEST_CASE_9_PAIR_2 27U, 4U, 33U, 58U, 19U, 42U, 6U, 31U, 22U + +// hidden layer count = 9 +#define TEST_CASE_10 5U, 48U, 2U, 37U, 14U, 51U, 9U, 28U, 41U, 17U, 60U +#define TEST_CASE_10_PAIR_1 5U, 48U, 2U, 37U, 14U, 51U, 9U, 28U, 41U, 17U +#define TEST_CASE_10_PAIR_2 48U, 2U, 37U, 14U, 51U, 9U, 28U, 41U, 17U, 60U + +// hidden layer count = 10 +#define TEST_CASE_11 8U, 36U, 50U, 3U, 22U, 57U, 14U, 44U, 6U, 31U, 12U, 59U +#define TEST_CASE_11_PAIR_1 8U, 36U, 50U, 3U, 22U, 57U, 14U, 44U, 6U, 31U, 12U +#define TEST_CASE_11_PAIR_2 36U, 50U, 3U, 22U, 57U, 14U, 44U, 6U, 31U, 12U, 59U + +// hidden layer count = 11 +#define TEST_CASE_12 7U, 49U, 18U, 5U, 33U, 41U, 2U, 28U, 11U, 54U, 19U, 47U, 23U +#define TEST_CASE_12_PAIR_1 7U, 49U, 18U, 5U, 33U, 41U, 2U, 28U, 11U, 54U, 19U, 47U +#define TEST_CASE_12_PAIR_2 49U, 18U, 5U, 33U, 41U, 2U, 28U, 11U, 54U, 19U, 
47U, 23U + +// hidden layer count = 12 +#define TEST_CASE_13 12U, 3U, 56U, 7U, 21U, 49U, 15U, 28U, 5U, 37U, 11U, 44U, 6U, 30U +#define TEST_CASE_13_PAIR_1 12U, 3U, 56U, 7U, 21U, 49U, 15U, 28U, 5U, 37U, 11U, 44U, 6U +#define TEST_CASE_13_PAIR_2 3U, 56U, 7U, 21U, 49U, 15U, 28U, 5U, 37U, 11U, 44U, 6U, 30U + +// hidden layer count = 13 +#define TEST_CASE_14 9U, 27U, 41U, 6U, 32U, 18U, 50U, 4U, 39U, 12U, 23U, 55U, 7U, 34U, 16U +#define TEST_CASE_14_PAIR_1 9U, 27U, 41U, 6U, 32U, 18U, 50U, 4U, 39U, 12U, 23U, 55U, 7U, 34U +#define TEST_CASE_14_PAIR_2 27U, 41U, 6U, 32U, 18U, 50U, 4U, 39U, 12U, 23U, 55U, 7U, 34U, 16U + +// hidden layer count = 14 +#define TEST_CASE_15 2U, 45U, 10U, 51U, 8U, 36U, 19U, 5U, 28U, 43U, 12U, 56U, 7U, 34U, 21U, 48U +#define TEST_CASE_15_PAIR_1 2U, 45U, 10U, 51U, 8U, 36U, 19U, 5U, 28U, 43U, 12U, 56U, 7U, 34U, 21U +#define TEST_CASE_15_PAIR_2 45U, 10U, 51U, 8U, 36U, 19U, 5U, 28U, 43U, 12U, 56U, 7U, 34U, 21U, 48U + +// hidden layer count = 15 +#define TEST_CASE_16 14U, 3U, 49U, 8U, 37U, 6U, 28U, 11U, 41U, 2U, 33U, 17U, 55U, 9U, 44U, 12U, 29U +#define TEST_CASE_16_PAIR_1 14U, 3U, 49U, 8U, 37U, 6U, 28U, 11U, 41U, 2U, 33U, 17U, 55U, 9U, 44U, 12U +#define TEST_CASE_16_PAIR_2 3U, 49U, 8U, 37U, 6U, 28U, 11U, 41U, 2U, 33U, 17U, 55U, 9U, 44U, 12U, 29U + +#define RunTest(N) \ +do { \ + uint expectedValue = verifySize(TEST_CASE_##N##_PAIR_1, TEST_CASE_##N##_PAIR_2); \ + uint actualValue = SharedMemorySizeInference.OfLayer##N.Bytes; \ + bool isCorrect = expectedValue == actualValue; \ + actualValue = SharedMemorySizeTraining.OfLayer##N.Bytes; \ + expectedValue = verifySize(TEST_CASE_##N##_PAIR_1, TEST_CASE_##N##_PAIR_2); \ + isCorrect = isCorrect && (expectedValue == actualValue); \ + outputBuffer[N-1] = isCorrect ? 1U : 0U; \ +} while (false) + + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint tid: SV_DispatchThreadID) +{ + __target_switch + { + case cuda: + typealias SharedMemorySizeInference = SharedMemorySize; + typealias SharedMemorySizeTraining = SharedMemorySize< DType, TargetEnum.CUDA, ExecutionMode.Training, SubgroupSize, BatchSize/SubgroupSize>; + RunTest(1); + RunTest(2); + RunTest(3); + RunTest(4); + RunTest(5); + RunTest(6); + RunTest(7); + RunTest(8); + RunTest(9); + RunTest(10); + RunTest(11); + RunTest(12); + RunTest(13); + RunTest(14); + RunTest(15); + RunTest(16); + case spirv: + typealias SharedMemorySizeInference = SharedMemorySize; + typealias SharedMemorySizeTraining = SharedMemorySize< DType, TargetEnum.SPIR_V, ExecutionMode.Training, SubgroupSize, BatchSize/SubgroupSize>; + RunTest(1); + RunTest(2); + RunTest(3); + RunTest(4); + RunTest(5); + RunTest(6); + RunTest(7); + RunTest(8); + RunTest(9); + RunTest(10); + RunTest(11); + RunTest(12); + RunTest(13); + RunTest(14); + RunTest(15); + RunTest(16); + } + + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/tiled-mma-load-test-aligned.slang b/tests/neural/tiled-mma-load-test-aligned.slang new file mode 100644 index 0000000000..3fb844c863 --- /dev/null +++ b/tests/neural/tiled-mma-load-test-aligned.slang @@ -0,0 +1,246 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// 
TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// +// Since tileA is 32×32 bytes, each thread loads 16 elements per row. +// test1: Expected output for the first tile: +// +// 0 1 2 3 ... 15 +// 32 33 34 35 ... 47 +// ... +// 960 961 962 963 ... 975 +// 992 993 994 995 ... 1007 +// +// test2: Expected output for the second tile: +// +// 16 17 18 19 ... 31 +// 48 49 50 51 ... 63 +// ... +// 1008 1009 1010 1011 ... 1023 + +// ============================================================================ +// MATRIX B - Loaded from Thread-Local Vector +// ============================================================================ +// Each thread constructs a local vector where: +// element[i] = tid * 32 + i, for i = 0, 1, 2, ... 31 +// +// Since tileB is 32×32 bytes, each thread loads 16 elements per column. +// test1: Expected output for the first tile (column-major): +// +// 0 32 ... 960 992 +// 1 33 ... 961 993 +// 2 34 ... 962 994 +// 3 35 ... 963 995 +// ... +// 15 47 ... 975 1007 + +// test2: Expected output for the second tile (column-major): +// +// 16 48 ... 1008 1040 +// 17 49 ... 1009 1041 +// 18 50 ... 1010 1042 +// 19 51 ... 1011 1043 +// ... +// 31 63 ... 1023 1055 + +// Make the weight matrix as 32x32 matrix in row major order +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=4):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + + +void initWeightMatrix(uint tid) +{ + inputBuffer[tid] = half(tid); +} + +static const int InputSize = 32; +static const int OutputSize = 32; +static const int SubgroupSize = 32; + +// Tile A is size of OutputSize * 32 bytes +groupshared uint4 s_sharedMemoryA[OutputSize * 2]; + +// Tile B is size of 32 x SubgroupSize bytes +groupshared uint4 s_sharedMemoryB[SubgroupSize * 2]; + +typealias SPtr = Ptr; + +void testLoadShA(uint tid, uint tileIndex) +{ + typealias Storage = StructuredBufferStorage; + + Storage storage = Storage(inputBuffer); + + SPtr sharedMemoryA = __getAddress(s_sharedMemoryA[0]); + MMAHelper.loadShA(sharedMemoryA, tileIndex, storage, 0); + GroupMemoryBarrierWithWaveSync(); +} + +void testLoadShB(uint tid, uint tileIndex) +{ + // Construct the input vector as follow: + // x = tid * 32 + 0, 1, 2, 3 ... 31 + half inputVector[InputSize]; + for (int i = 0; i < InputSize; i++) + { + inputVector[i] = half(tid * InputSize + i); + } + + SPtr sharedMemoryB = __getAddress(s_sharedMemoryB[0]); + // This test only has one subgroup, so the subgroup index is always 0. 
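+    // Each column of tile B holds 16 half values (2 uint4 = 32 bytes per thread, SubgroupSize
+    // columns in total), so tileIndex 0 covers inputVector[0..15] and tileIndex 1 covers
+    // inputVector[16..31], matching the column-major layouts sketched at the top of this file.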
+ MMAHelper.loadVectorToShB(sharedMemoryB, tileIndex, 0, inputVector); + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint tid, uint size, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one row/column of the shared memory. + // So each thread will check 2 uint4 elements (32 bytes/16 half) in the shared memory. + half expected = half(tid * 32 + tileIndex * 16); + bool res = true; + uint index = tid * 16; // 16 half per thread + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[tid * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + + if (aa != expected) + { + res = false; + break; + } + expected += half(1.0f); + if (bb != expected) + { + res = false; + break; + } + expected += half(1.0f); + } + } + + return res; +} + +void test1(uint tid) +{ + if (tid >= SubgroupSize) + return; + + testLoadShA(tid, 0); + + // serialRead(tid, __getAddress(s_sharedMemoryA[0])); + + bool res = verifiedOutput(tid, OutputSize, __getAddress(s_sharedMemoryA[0]), 0); + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[0] = res ? 1 : 0; + // BUFFER: 1 + + testLoadShB(tid, 0); + + bool res1 = verifiedOutput(tid, InputSize, __getAddress(s_sharedMemoryB[0]), 0); + res1 = WaveActiveAllTrue(res1); + if (tid == 0) + outputBuffer[1] = res1 ? 1 : 0; + // BUFFER-NEXT: 1 +} + +void test2(uint tid) +{ + if (tid >= SubgroupSize) + return; + + testLoadShA(tid, 1); + + bool res = verifiedOutput(tid, OutputSize, __getAddress(s_sharedMemoryA[0]), 1); + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[2] = res ? 1 : 0; + // BUFFER-NEXT: 1 + + testLoadShB(tid, 1); + + bool res1 = verifiedOutput(tid, InputSize, __getAddress(s_sharedMemoryB[0]), 1); + res1 = WaveActiveAllTrue(res1); + if (tid == 0) + outputBuffer[3] = res1 ? 1 : 0; + // BUFFER-NEXT: 1 +} + +// This function is just used for debugging, not for verification. So keep it here. 
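+// Note: verifiedOutput() above and serialRead() below both assume that each 32-bit lane of a
+// uint4 packs two half values, low 16 bits first. A minimal unpack helper would look like the
+// following sketch (hypothetical name, not used by the test):
+//
+//   void unpackHalf2(uint packed, out half lo, out half hi)
+//   {
+//       lo = bit_cast<half>((uint16_t)(packed & 0xFFFF));          // bits [0:15]
+//       hi = bit_cast<half>((uint16_t)((packed >> 16) & 0xFFFF));  // bits [16:31]
+//   }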
+void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithWaveSync(); + + if (tid > 0) + return; + + for (int id = 0; id < 32; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +[numthreads(1024, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initWeightMatrix(tid); + + __target_switch + { + case cuda: + test1(tid); + test2(tid); + break; + case spirv: + test1(tid); + test2(tid); + break; + } +} diff --git a/tests/neural/tiled-mma-load-test-transpose-aligned.slang b/tests/neural/tiled-mma-load-test-transpose-aligned.slang new file mode 100644 index 0000000000..3bc48d36f5 --- /dev/null +++ b/tests/neural/tiled-mma-load-test-transpose-aligned.slang @@ -0,0 +1,188 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 30: 960 961 962 963 ... 975 976 977 ... 991 +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// +// In the transpose mode, each tile still load a column of CoopMat of A^T, but the +// memory layout is in column major. Therefore, each load will results in a +// AlignedK x WMMA_TileWidth (16 for half type) matrix. + +// test1: Expected output for the first tile: +// col0: 0 1 2 3 ... 31 +// col1: 32 33 34 35 ... 63 +// ... +// col15: 480 481 482 483 ... 511 + +// test2: Expected output for the second tile: +// col0: 512 513 514 515 ... 543 +// col1: 544 545 546 547 ... 575 +// ... +// col15: 992 993 994 995 ... 
1023 + + +// Make the weight matrix as 32x32 matrix in row major order +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=2):out, name=outputBuffer +RWStructuredBuffer outputBuffer; + + +void initWeightMatrix(uint tid) +{ + inputBuffer[tid] = half(tid); +} + +static const int InputSize = 32; +static const int OutputSize = 32; +static const int SubgroupSize = 32; + +// Tile A is size of OutputSize * 32 bytes +groupshared uint4 s_sharedMemoryA[OutputSize * 2]; + +// Tile B is size of 32 x SubgroupSize bytes +groupshared uint4 s_sharedMemoryB[SubgroupSize * 2]; + +typealias SPtr = Ptr; + +void testLoadShA(uint tid, uint tileIndex) +{ + typealias Storage = StructuredBufferStorage; + + Storage storage = Storage(inputBuffer); + + SPtr sharedMemoryA = __getAddress(s_sharedMemoryA[0]); + MMAHelper.loadShA(sharedMemoryA, tileIndex, storage, 0); + GroupMemoryBarrierWithWaveSync(); +} + + +bool verifiedOutput(uint tid, uint size, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one column of the shared memory. + // The tile width of A is just 16 for half type. + const int WMMA_TileWidth = 16; + const int WMMA_TileHeight = 16; + if (tid >= WMMA_TileWidth) + return true; + + const int AlignedK = ((InputSize + WMMA_TileWidth - 1) / WMMA_TileWidth) * WMMA_TileWidth; + const int TileSize = AlignedK * WMMA_TileWidth; + const int ElementCountPerVector = sizeof(uint4) / sizeof(half); + const int NumVectorsPerColumn = AlignedK / ElementCountPerVector; + + half expected = half(tid * AlignedK + tileIndex * TileSize); + bool res = true; + uint indexInTile = tid * NumVectorsPerColumn; // AlignedK half per thread + + for (int i = 0; i < NumVectorsPerColumn; i++) + { + uint4 values = sharedMem[indexInTile + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + + if (aa != expected) + { + res = false; + break; + } + expected += half(1.0f); + if (bb != expected) + { + res = false; + break; + } + expected += half(1.0f); + } + } + + return res; +} + +void test(uint tid, uint tileIndex, uint resIndex) +{ + if (tid >= SubgroupSize) + return; + + __target_switch + { + case cuda: + testLoadShA(tid, tileIndex); + break; + case spirv: + testLoadShA(tid, tileIndex); + break; + } + // serialRead(tid, __getAddress(s_sharedMemoryA[0])); + + bool res = verifiedOutput(tid, OutputSize, __getAddress(s_sharedMemoryA[0]), tileIndex); + res = WaveActiveAllTrue(res); + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + +// This function is just used for debugging, not for verification. So keep it here. 
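+// Note: in this transpose test InputSize == 32, so AlignedK == 32 and each column of the tile
+// occupies 32 half values = 4 uint4 vectors. That is why verifiedOutput() above derives
+// NumVectorsPerColumn = AlignedK / (sizeof(uint4) / sizeof(half)) = 4 and serialRead() below
+// indexes sharedMem[id * 4 + i].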
+void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithWaveSync(); + + if (tid > 0) + return; + + // In transpose mode, tile is in column major, and each length is AlignedK length + for (int id = 0; id < 16; id++) + { + printf("col: %d\n", id); + for (int i = 0; i < 4; i++) + { + uint4 values = sharedMem[id * 4 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +[numthreads(1024, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initWeightMatrix(tid); + + test(tid, 0, 0); + // BUFFER: 1 + test(tid, 1, 1); + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/tiled-mma-load-test-transpose-unAligned-width-A.slang b/tests/neural/tiled-mma-load-test-transpose-unAligned-width-A.slang new file mode 100644 index 0000000000..a1766d9d73 --- /dev/null +++ b/tests/neural/tiled-mma-load-test-transpose-unAligned-width-A.slang @@ -0,0 +1,268 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly when the height of the matrix +// is not aligned with the tile height in transpose mode. In transpose mode, the height of the matrix +// is the width of the transposed matrix, so the name of this test. +// The result of the tile will always be aligned with 16 in columns with 0 paddings. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 7: 224 225 226 227 ... 239 240 241 ... 255 +// Row 8: 256 257 258 259 ... 269 270 271 ... 287 +// ... +// Row 14: 448 449 450 451 ... 459 460 461 ... 479 +// Row 15: 480 481 482 483 ... 491 492 493 ... 511 +// ... +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + +// We will test different heights of the matrix, all the heights are smaller than 16. +// For example, if we set the height to 3, the matrix will be: +// +// 0 1 2 ... 15 +// 32 33 34 ... 47 +// 64 65 66 ... 79 +// [paddings for row 3-15] + +// the transposed matrix will be: +// 0 32 64 [paddings for columns 3-15] +// 1 33 65 [paddings for columns 3-15] +// 2 34 66 [paddings for columns 3-15] +// ... +// 15 47 79 [paddings for columns 3-15] + +// so the shared memory will be: +// col0: 0 1 2 ... 15 +// col1: 1 33 65 ... 79 +// col2: 2 34 66 ... 79 +// col3: 0 0 0 ... 0 +// ... +// col13: 0 0 0 ... 0 +// col14: 0 0 0 ... 0 +// col15: 0 0 0 ... 
0 + + +// Make the weight matrix as 16x32 matrix in row major order +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=9):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +void initWeightMatrix(uint tid) +{ + inputBuffer[tid] = half(tid); +} + +static const int InputSize = 32; +static const int SubgroupSize = 32; + +// Tile A is size of InputSize/8 x 16 in uint4. +static const int TileSize = (InputSize / 8) * 16; + +// We double the size of the shared memory to also test if there is out-of-bound write issue. +groupshared uint4 s_sharedMemoryA[TileSize * 2]; + +typealias SPtr = Ptr; + +void testLoadShA(uint tid, uint tileIndex) +{ + typealias Storage = StructuredBufferStorage; + + Storage storage = Storage(inputBuffer); + + SPtr sharedMemoryA = __getAddress(s_sharedMemoryA[0]); + MMAHelper.loadShA(sharedMemoryA, tileIndex, storage, 0); + GroupMemoryBarrierWithWaveSync(); +} + +void invalidateSharedMemory(uint tid, SPtr shmPtr) +{ + if (tid >= SubgroupSize) + return; + + // Initialize the shared memory with all 1.0h. + uint activeThreadCount = WaveActiveCountBits(true); + uint numIters = (TileSize * 2 + activeThreadCount - 1) / activeThreadCount; + for (int i = 0; i < numIters; i++) + { + uint index = tid * numIters + i; + if (index >= TileSize * 2) + break; + + shmPtr[index] = uint4(0x3C003C00); + } + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint tid, uint size, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one column of the shared memory. + // The tile width of A is just 16 for half type. + const int WMMA_TileWidth = 16; + const int WMMA_TileHeight = 16; + + const int AlignedK = ((InputSize + WMMA_TileWidth - 1) / WMMA_TileWidth) * WMMA_TileWidth; + const int TileSizeInElements = AlignedK * WMMA_TileWidth; + const int ElementCountPerVector = sizeof(uint4) / sizeof(half); + const int NumVectorsPerColumn = AlignedK / ElementCountPerVector; + + // Verify the output is correct, each thread will verify one row/column of the shared memory. + // So each thread will check 2 uint4 elements (32 bytes/16 half) in the shared memory. + half expected = half(tid * AlignedK + tileIndex * TileSizeInElements); + bool res = true; + + if (tid < WMMA_TileWidth) + { + for (int i = 0; i < NumVectorsPerColumn; i++) + { + uint indexInTile = tid * NumVectorsPerColumn + i; + uint4 values = sharedMem[indexInTile]; + uint4 element = values; + if (indexInTile / NumVectorsPerColumn + tileIndex * WMMA_TileWidth >= M) + { + // Checking paddings are correct for out-of-range elements. + if (!values.equals(uint4(0, 0, 0, 0))) + { + return false; + } + continue; + } + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half actual[2] = { bit_cast(a), bit_cast(b) }; + half expectedValues[2] = { expected, expected + 1.0h }; + for (int verifyIndex = 0; verifyIndex < 2; verifyIndex++) + { + if (actual[verifyIndex] != expectedValues[verifyIndex]) + { + return false; + } + } + expected += 2.0h; + } + } + } + + + { + // For out-of-range rows, we just check if the values are same as the initialized values. + // Because we also need to check if the library accidentally write some values to the out-of-range columns. 
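+        // Note: 0x3C003C00 packs two half values of 1.0 (0x3C00 each) per 32-bit lane, so a
+        // lane that still compares equal to the fill value written by invalidateSharedMemory()
+        // has not been touched by loadShA().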
+ int startColumn = (tileIndex + 1) * WMMA_TileWidth; + int startIndex = startColumn * NumVectorsPerColumn; + + for (int i = 0; i < NumVectorsPerColumn; i++) + { + uint indexInTile = (tid * NumVectorsPerColumn + i) + startIndex; + if (indexInTile >= TileSize * 2) + break; + + uint4 values = sharedMem[indexInTile]; + if (!values.equals(uint4(0x3C003C00))) + { + return false; + } + } + } + + return true; +} + +void Test(uint tid, int tileIndex, int resIndex) +{ + invalidateSharedMemory(tid, __getAddress(s_sharedMemoryA[0])); + __target_switch + { + case cuda: + testLoadShA(tid, tileIndex); + break; + case spirv: + testLoadShA(tid, tileIndex); + break; + } + // serialRead(tid, __getAddress(s_sharedMemoryA[0])); + + bool res = verifiedOutput(tid, M, __getAddress(s_sharedMemoryA[0]), tileIndex); + res = WaveActiveAllTrue(res); + + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + +// This function is just used for debugging, not for verification. So keep it here. +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithWaveSync(); + + if (tid > 0) + return; + + // In transpose mode, tile is in column major, and each length is AlignedK length + for (int id = 0; id < 32; id++) + { + printf("col: %d\n", id); + for (int i = 0; i < 4; i++) + { + uint4 values = sharedMem[id * 4 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +[numthreads(1024, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initWeightMatrix(tid); + + if (tid >= SubgroupSize) + return; + + // smaller than 16 cases, it will padded to 16 rows. + Test<13>(tid, 0, 0); + Test<9>(tid, 0, 1); + Test<7>(tid, 0, 2); + Test<3>(tid, 0, 3); + Test<1>(tid, 0, 4); + + Test<19>(tid, 0, 5); // bigger than 16 case, it will padded to 32 rows. + Test<19>(tid, 1, 6); // bigger than 16 case, it will padded to 32 rows. + Test<25>(tid, 0, 7); + Test<25>(tid, 1, 8); + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/tiled-mma-load-test-transpose-unAligned-width-height-A.slang b/tests/neural/tiled-mma-load-test-transpose-unAligned-width-height-A.slang new file mode 100644 index 0000000000..367dbfd1b5 --- /dev/null +++ b/tests/neural/tiled-mma-load-test-transpose-unAligned-width-height-A.slang @@ -0,0 +1,273 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly when the width, height of the matrix +// is not aligned with the tile width, height. The result of the tile will always be aligned with 16 each row +// and column with 0 paddings. 
+ +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 7: 224 225 226 227 ... 239 240 241 ... 255 +// Row 8: 256 257 258 259 ... 269 270 271 ... 287 +// ... +// Row 14: 448 449 450 451 ... 459 460 461 ... 479 +// Row 15: 480 481 482 483 ... 491 492 493 ... 511 +// ... +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + +// We will test different widths, heights of the matrix, for example, +// if the width == 19, and height == 3, and if we read the 2nd tile, the result will be: +// +// 16 17 18 ... 0 [13 zeros padding] +// 35 36 37 ... 0 [13 zeros padding] +// 54 55 56 ... 0 [13 zeros padding] +// ... +// [13 zeros padding rows] + +// Explanation: +// Think about the input matrix is a 1-D buffer from 0-1023, since the width of the matrix is 19 now, +// the stride of the matrix is 19. And height is 3, so the matrix will be: +// So the matrix will be: +// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 +// 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 +// 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 + +// Because we are reading the 2nd tile, so it will start from the 17th column (index 16), so there +// will be only 3 elements left in each row, the remaining 13 elements will be padded with 0: +// +// 16 17 18 ... 0 [13 zeros padding] +// 35 36 37 ... 0 [13 zeros padding] +// 54 55 56 ... 0 [13 zeros padding] +// ... +// [13 zeros padding rows] + + +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=7):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +void initWeightMatrix(uint tid) +{ + inputBuffer[tid] = half(tid); +} + +void invalidateSharedMemoryA(uint tid) +{ + if (tid >= SubgroupSize) + return; + + // Initialize the shared memory with all 1.0h. + uint activeThreadCount = WaveActiveCountBits(true); + uint numIters = (TileSize * 2 + activeThreadCount - 1) / activeThreadCount; + for (int i = 0; i < numIters; i++) + { + uint index = tid * numIters + i; + if (index >= TileSize * 2) + break; + + s_sharedMemoryA[index] = uint4(0x3C003C00); + } + GroupMemoryBarrierWithWaveSync(); +} + +static const int SubgroupSize = 32; + +static const int TileSize = 4 * 16; +groupshared uint4 s_sharedMemoryA[TileSize * 2]; + +typealias SPtr = Ptr; + +void testLoadShA(uint tid, uint tileIndex) +{ + typealias Storage = StructuredBufferStorage; + + Storage storage = Storage(inputBuffer); + + SPtr sharedMemoryA = __getAddress(s_sharedMemoryA[0]); + MMAHelper.loadShA(sharedMemoryA, tileIndex, storage, 0); + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint tid, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one column of the shared memory. + // The tile width of A is just 16 for half type. 
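+    // Worked example (derived from the expected-value formula below): thread tid checks the
+    // shared-memory column that holds row (tileIndex * 16 + tid) of the M x K matrix, and row r
+    // of the linear 0..1023 input starts at value r * K. For M == 19, K == 5, tileIndex == 1,
+    // thread 2 therefore expects its column to start at 18 * 5 == 90 == 2 * 5 + 1 * 5 * 16.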
+ const int WMMA_TileWidth = 16; + + const int AlignedK = ((K + WMMA_TileWidth - 1) / WMMA_TileWidth) * WMMA_TileWidth; + const int TileSizeInElements = AlignedK * WMMA_TileWidth; + const int ElementCountPerVector = sizeof(uint4) / sizeof(half); + const int NumVectorsPerColumn = AlignedK / ElementCountPerVector; + + // Verify the output is correct, each thread will verify one row/column of the shared memory. + half expected = half(tid * K + (tileIndex * K * WMMA_TileWidth)); + bool res = true; + + if (tid < WMMA_TileWidth) + { + for (int i = 0; i < NumVectorsPerColumn; i++) + { + uint indexInTile = tid * NumVectorsPerColumn + i; + uint4 element = sharedMem[indexInTile]; + if (indexInTile / NumVectorsPerColumn + tileIndex * WMMA_TileWidth >= M) + { + // Checking paddings are correct for out-of-range elements. + if (!element.equals(uint4(0, 0, 0, 0))) + { + return false; + } + continue; + } + + uint yIndex = ElementCountPerVector * i; + + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half actual[2] = { bit_cast(a), bit_cast(b) }; + half expectedValues[2] = { expected, expected + 1.0h }; + for (int verifyIndex = 0; verifyIndex < 2; verifyIndex++) + { + if (yIndex++ >= K) + { + if (actual[verifyIndex] != 0.0h) + { + return false; + } + continue; + } + if (actual[verifyIndex] != expectedValues[verifyIndex]) + { + return false; + } + } + expected += 2.0h; + } + } + } + + + { + // For out-of-range rows, we just check if the values are same as the initialized values. + // Because we also need to check if the library accidentally write some values to the out-of-range columns. + int startColumn = (tileIndex + 1) * WMMA_TileWidth; + int startIndex = startColumn * NumVectorsPerColumn; + + for (int i = 0; i < NumVectorsPerColumn; i++) + { + uint indexInTile = (tid * NumVectorsPerColumn + i) + startIndex; + if (indexInTile >= TileSize * 2) + break; + + uint4 values = sharedMem[indexInTile]; + if (!values.equals(uint4(0x3C003C00))) + { + return false; + } + } + } + + return true; +} + +void Test(uint tid, uint tileIndex, int resIndex) +{ + if (tid >= SubgroupSize) + return; + + // fill the shared memory with invalid values. + invalidateSharedMemoryA(tid); + + __target_switch + { + case cuda: + testLoadShA(tid, tileIndex); + break; + case spirv: + testLoadShA(tid, tileIndex); + break; + } + // serialRead(tid, __getAddress(s_sharedMemoryA[0])); + + bool res = verifiedOutput(tid, __getAddress(s_sharedMemoryA[0]), tileIndex); + res = WaveActiveAllTrue(res); + + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + +// This function is just used for debugging, not for verification. So keep it here. 
+void serialRead(uint tid, SPtr sharedMem) +{ + if (tid > 0) + return; + + uint AlignedK = ((K + 16 - 1) / 16) * 16; + uint NumVectorsPerColumn = AlignedK / 8; + + for (int col = 0; col < 16; col++) + { + printf("col: %d\n", col); + for (int i = 0; i < NumVectorsPerColumn; i++) + { + uint4 values = sharedMem[col * NumVectorsPerColumn + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +[numthreads(1024, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initWeightMatrix(tid); + + Test<3, 19>(tid, 0, 0); // M == 3, K == 19, read tile 1 + // BUFFER: 1 + + Test<19, 5>(tid, 0, 1); // M == 19, K == 5, read tile 0 + // BUFFER-NEXT: 1 + + Test<19, 5>(tid, 1, 2); // M == 19, K == 5, read tile 1 + // BUFFER-NEXT: 1 + + Test<5, 32>(tid, 0, 3); // M == 5, K == 32, read tile 0 + // BUFFER-NEXT: 1 + + Test<14, 13>(tid, 0, 4); // M == 14, K == 13, read tile 0 + // BUFFER-NEXT: 1 + + Test<25, 16>(tid, 0, 5); // M == 25, K == 16, read tile 0 + // BUFFER-NEXT: 1 + + Test<25, 16>(tid, 1, 6); // M == 25, K == 16, read tile 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/tiled-mma-load-test-unAligned-height-A.slang b/tests/neural/tiled-mma-load-test-unAligned-height-A.slang new file mode 100644 index 0000000000..27a1b300fc --- /dev/null +++ b/tests/neural/tiled-mma-load-test-unAligned-height-A.slang @@ -0,0 +1,223 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly when the height of the matrix +// is not aligned with the tile height. The result of the tile will always be aligned with 16 in columns +// with 0 paddings. +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 7: 224 225 226 227 ... 239 240 241 ... 255 +// Row 8: 256 257 258 259 ... 269 270 271 ... 287 +// ... +// Row 14: 448 449 450 451 ... 459 460 461 ... 479 +// Row 15: 480 481 482 483 ... 491 492 493 ... 511 +// ... +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + +// We will test different heights of the matrix, all the heights are smaller than 16. +// For example, if we set the height to 3, the matrix will be: +// +// 0 1 2 ... 15 +// 32 33 34 ... 47 +// 64 65 66 ... 
79 +// [paddings for row 3-15] + +// Make the weight matrix as 16x32 matrix in row major order +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=7):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +void initWeightMatrix(uint tid) +{ + inputBuffer[tid] = half(tid); +} + +static const int InputSize = 32; +static const int SubgroupSize = 32; + +// Tile A is size of 32 * 32 bytes: TileHeight is at least 16. +groupshared uint4 s_sharedMemoryA[32 * 2]; + +typealias SPtr = Ptr; + +void testLoadShA(uint tid, uint tileIndex) +{ + typealias Storage = StructuredBufferStorage; + + Storage storage = Storage(inputBuffer); + + SPtr sharedMemoryA = __getAddress(s_sharedMemoryA[0]); + MMAHelper.loadShA(sharedMemoryA, tileIndex, storage, 0); + GroupMemoryBarrierWithWaveSync(); +} + +void invalidateSharedMemory(uint tid, SPtr shmPtr) +{ + if (tid >= SubgroupSize) + return; + + // Initialize the shared memory with all 1.0h. + uint activeThreadCount = WaveActiveCountBits(true); + uint numIters = (32 * 2 + activeThreadCount - 1) / activeThreadCount; + for (int i = 0; i < numIters; i++) + { + uint index = tid * numIters + i; + if (index >= 32 * 2) + break; + + shmPtr[index] = uint4(0x3C003C00); + } + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint tid, uint size, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one row/column of the shared memory. + // So each thread will check 2 uint4 elements (32 bytes/16 half) in the shared memory. + half expected = half(tid * 32 + tileIndex * 16); + bool res = true; + uint index = tid * 16; // 16 half per thread + uint alignedM = ((M + 16 - 1) / 16) * 16; + + if (tid < alignedM) + { + for (int i = 0; i < 2; i++) + { + uint indexInTile = tid * 2 + i; + uint4 values = sharedMem[indexInTile]; + uint4 element = values; + if (indexInTile / 2 >= M) + { + // Checking paddings are correct for out-of-range elements. + if (!values.equals(uint4(0, 0, 0, 0))) + { + return false; + } + continue; + } + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half actual[2] = { bit_cast(a), bit_cast(b) }; + half expectedValues[2] = { expected, expected + 1.0h }; + for (int verifyIndex = 0; verifyIndex < 2; verifyIndex++) + { + if (actual[verifyIndex] != expectedValues[verifyIndex]) + { + return false; + } + } + expected += 2.0h; + } + } + } + else + { + // For out-of-range rows, we just check if the values are same as the initialized values. + // Because we also need to check if the library accidentally write some values to the out-of-range rows. + for (int i = 0; i < 2; i++) + { + uint indexInTile = tid * 2 + i; + uint4 values = sharedMem[indexInTile]; + if (!values.equals(uint4(0x3C003C00))) + { + return false; + } + } + } + + return true; +} + +void Test(uint tid, int resIndex) +{ + invalidateSharedMemory(tid, __getAddress(s_sharedMemoryA[0])); + __target_switch + { + case cuda: + testLoadShA(tid, 0); + break; + case spirv: + testLoadShA(tid, 0); + break; + } + // serialRead(tid, __getAddress(s_sharedMemoryA[0])); + + bool res = verifiedOutput(tid, M, __getAddress(s_sharedMemoryA[0]), 0); + res = WaveActiveAllTrue(res); + + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + +// This function is just used for debugging, not for verification. So keep it here. 
+void serialRead(uint tid, SPtr sharedMem) +{ + if (tid > 0) + return; + + for (int id = 0; id < 32; id++) + { + printf("tid: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +[numthreads(1024, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + initWeightMatrix(tid); + + if (tid >= SubgroupSize) + return; + + // smaller than 16 cases, it will padded to 16 rows. + Test<13>(tid, 0); + Test<9>(tid, 1); + Test<7>(tid, 2); + Test<3>(tid, 3); + Test<1>(tid, 4); + + Test<19>(tid, 5); // bigger than 16 case, it will padded to 32 rows. + Test<25>(tid, 6); + // BUFFER: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/tiled-mma-load-test-unAligned-height-B.slang b/tests/neural/tiled-mma-load-test-unAligned-height-B.slang new file mode 100644 index 0000000000..e13e3c841c --- /dev/null +++ b/tests/neural/tiled-mma-load-test-unAligned-height-B.slang @@ -0,0 +1,209 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly when the height of the matrix +// is not aligned with the tile height. The result of the tile will always be aligned with 16 in columns +// with 0 paddings. +// +// We will construct the input vector as follow: +// x = tid * 32 + 0, 1, 2, 3 ... 31 + +// We will test different heights of the matrix B. +// For example, if we set the height to 19, and read the second tile, the matrix will be: +// +// 16 48 80 ... 1008 +// 17 49 81 ... 1009 +// 18 50 82 ... 1010 +// [zeros rows for height 3-15] + +// Explanation: According the way of how we construct the input vector, the matrix will be: +// +// 0 32 64 ... 992 +// 1 33 65 ... 993 +// 2 34 66 ... 994 +// ... +// 31 65 97 ... 1023 +// in major column. + +// So if we set height to 19, the matrix will be: +// +// 0 32 64 ... 992 +// 1 33 65 ... 993 +// 2 34 66 ... 994 +// ... +// 18 48 78 ... 1008 +// [zeros rows for height 19-31] + +// Since we are reading the second tile, it will start from the 16th row, so the matrix will be: +// +// 16 48 80 ... 1008 +// 17 49 81 ... 1009 +// 18 50 82 ... 1010 +// [zeros rows for height 3-15] + + +// TEST_INPUT:ubuffer(stride=4, count=5):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +static const int OutputSize = 32; +static const int SubgroupSize = 32; + +// Tile A is size of 32 * 32 bytes: TileHeight is at least 16. +groupshared uint4 s_sharedMemoryB[32 * 2]; + +typealias SPtr = Ptr; + +void testLoadShB(uint tid, uint tileIndex) +{ + half inputVector[MMAHelper.Uint4AlignedK] = {}; + for (int i = 0; i < K; i++) + { + inputVector[i] = half(tid * 32 + i); + } + + SPtr sharedMemoryB = __getAddress(s_sharedMemoryB[0]); + // This test only has one subgroup, so the subgroup index is always 0. 
+ MMAHelper.loadVectorToShB(sharedMemoryB, tileIndex, 0, inputVector); + GroupMemoryBarrierWithWaveSync(); +} + +void invalidateSharedMemory(uint tid, SPtr shmPtr) +{ + if (tid >= SubgroupSize) + return; + + // Initialize the shared memory with all 1.0h. + uint activeThreadCount = WaveActiveCountBits(true); + uint numIters = (32 * 2 + activeThreadCount - 1) / activeThreadCount; + for (int i = 0; i < numIters; i++) + { + uint index = tid * numIters + i; + if (index >= 32 * 2) + break; + + shmPtr[index] = uint4(0x3C003C00); + } + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint tid, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one column of the shared memory. + // So each thread will check 2 uint4 elements (32 bytes/16 half) in the shared memory. + half expected = half(tid * 32 + tileIndex * 16); + bool res = true; + uint index = tid * 16; // 16 half per thread + + for (int i = 0; i < 2; i++) + { + uint indexInTile = tid * 2 + i; + uint4 values = sharedMem[indexInTile]; + uint4 element = values; + + for (int j = 0; j < 4; j++) + { + uint yIndex = (i * 8 + j * 2) + tileIndex * 16; + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + half actual[2] = { bit_cast(a), bit_cast(b) }; + half expectedValues[2] = { expected, expected + 1.0h }; + + for (int verifyIndex = 0; verifyIndex < 2; verifyIndex++) + { + // Check out-of-boundary for X direction. + if (yIndex++ >= K) + { + if (actual[verifyIndex] != 0.0h) + { + return false; + } + continue; + } + if (actual[verifyIndex] != expectedValues[verifyIndex]) + { + return false; + } + } + expected += 2.0h; + } + } + + + return true; +} + +void Test(uint tid, uint tileIndex, int resIndex) +{ + invalidateSharedMemory(tid, __getAddress(s_sharedMemoryB[0])); + __target_switch + { + case cuda: + testLoadShB(tid, tileIndex); + break; + case spirv: + testLoadShB(tid, tileIndex); + break; + } + // serialRead(tid, __getAddress(s_sharedMemoryB[0])); + + bool res = verifiedOutput(tid, __getAddress(s_sharedMemoryB[0]), tileIndex); + res = WaveActiveAllTrue(res); + + if (tid == 0) + outputBuffer[resIndex] = res ? 1 : 0; +} + +// This function is just used for debugging, not for verification. So keep it here. 
+void serialRead(uint tid, SPtr sharedMem)
+{
+    if (tid > 0)
+        return;
+
+    for (int id = 0; id < 32; id++)
+    {
+        printf("tid: %d\n", id);
+        for (int i = 0; i < 2; i++)
+        {
+            uint4 values = sharedMem[id * 2 + i];
+            uint4 element = values;
+            for (int j = 0; j < 4; j++)
+            {
+                uint value = element[j];
+                uint16_t a = (uint16_t)(value & 0xFFFF);
+                uint16_t b = (uint16_t)((value >> 16) & 0xFFFF);
+
+                half aa = bit_cast(a);
+                half bb = bit_cast(b);
+                printf("%.1f %.1f ", float(aa), float(bb));
+            }
+        }
+        printf("\n");
+    }
+}
+
+[numthreads(1024, 1, 1)]
+[shader("compute")]
+void computeMain(uint tid : SV_DispatchThreadID)
+{
+    if (tid >= SubgroupSize)
+        return;
+
+    Test<19>(tid, 0, 0); // arbitrary case, K = 19, read tile 0
+    // BUFFER: 1
+
+    Test<19>(tid, 1, 1); // arbitrary case, K = 19, read tile 1
+    // BUFFER-NEXT: 1
+
+    Test<16>(tid, 0, 2); // boundary case, K = 16, read tile 0
+    // BUFFER-NEXT: 1
+
+    Test<32>(tid, 0, 3); // boundary case, K = 32, read tile 0
+    // BUFFER-NEXT: 1
+
+    Test<32>(tid, 1, 4); // boundary case, K = 32, read tile 1
+    // BUFFER-NEXT: 1
+}
diff --git a/tests/neural/tiled-mma-load-test-unAligned-width-B-multi-warps.slang b/tests/neural/tiled-mma-load-test-unAligned-width-B-multi-warps.slang
new file mode 100644
index 0000000000..19c6f96310
--- /dev/null
+++ b/tests/neural/tiled-mma-load-test-unAligned-width-B-multi-warps.slang
@@ -0,0 +1,286 @@
+// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly -xslang -DTHREAD_COUNT=64
+// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature -xslang -DTHREAD_COUNT=64
+
+import neural;
+#pragma warning(disable: 41017)
+
+// Similar to tiled-mma-load-test-unAligned-width-B.slang, but this test covers the multi-warp case.
+
+// If we launch N warps, where N > 1, the first N-1 warps are guaranteed to be full with 32 threads each, but the
+// N-th warp may be incomplete, so we check the last warp to make sure that the tiled matrix load operations still
+// work correctly.
+
+// In addition, each subgroup loads its own tile into the shared memory that is shared across the whole workgroup.
+// Each subgroup owns its own contiguous chunk of that shared memory (offset by its subgroup index), so we also
+// need to check that every load lands in the correct chunk.
+
+//
+// Construct the thread-local vector as follows:
+// x = tid * 32 + 0, 1, 2, 3 ... 31
+
+// So the whole matrix B (assume THREAD_COUNT = 40) will be:
+// Column   0   1   2  ...   31    32  ...   39
+//          0  32  64  ...  992  1024  ... 1248
+//          1  33  65  ...  993  1025  ... 1249
+//          2  34  66  ...  994  1026  ... 1250
+// ...
+//         31  63  95  ... 1023  1055  ... 1279
+//
+// Two warps are launched in total: the first warp loads the first 32 columns, and the second warp loads the remaining 8 columns.
+// The first warp will load the following columns:
+// Column   0   1   2  ...   31
+//          0  32  64  ...  992
+//          1  33  65  ...  993
+//          2  34  66  ...  994
+// ...
+//         31  63  95  ... 1023
+//
+// The second warp will load the remaining 8 columns:
+// Column    32    33  ...    39
+//         1024  1056  ...  1248
+//         1025  1057  ...  1249
+//         1026  1058  ...  1250
+// ...
+//         1055  1087  ...  1279
+
+// We only need to check the last warp: if all of its values are correct, the first N-1 (full) warps can be assumed
+// to be correct as well.
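To make the warp-to-column mapping above concrete, here is an informal Python sketch (not part of the test; the helper names are made up) of how each subgroup's expected column values can be derived:

```python
# Informal reference model of the multi-warp layout described above:
# B[row][col] = col * 32 + row, and subgroup w owns the contiguous column
# range [32 * w, min(32 * w + 32, thread_count)).
def warp_columns(thread_count, warp_index, warp_size=32):
    start = warp_index * warp_size
    return list(range(start, min(start + warp_size, thread_count)))

def expected_column(col, k=32):
    return [col * 32 + row for row in range(k)]

# With THREAD_COUNT = 40, warp 1 owns columns 32..39; column 32 holds 1024..1055.
print(warp_columns(40, 1))         # [32, 33, ..., 39]
print(expected_column(32)[:4])     # [1024, 1025, 1026, 1027]
```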
+ + +// Make the weight matrix as 32x32 matrix in row major order +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=2):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// InputSize is the K dimension of the matrix A, make it unaligned to test the padding logic. +static const int InputSize = 32; // K +static const int OutputSize = 32; // M +static const int SubgroupSize = 32; // N - Note, N dimension is up to the subgroup size. + +static const int SUBGROUP_COUNT = THREAD_COUNT / 32; + +// Tile B is size of 16 * 32 bytes. +static const int TILE_SIZE = 16 * 2; +static const int N_TILES_PER_SUBGROUP = 2; +groupshared uint4 s_sharedMemoryB[SUBGROUP_COUNT * N_TILES_PER_SUBGROUP * TILE_SIZE]; + +typealias SPtr = Ptr; + +void invalidateSharedMemory(uint tid, SPtr shmPtr) +{ + // Initialize the shared memory with all 1.0h. + // Each subgroup will initialize the shared memory for its own threads. + uint subgroupIndex = tid / 32; + uint laneId = WaveGetLaneIndex(); + uint offset = subgroupIndex * (TILE_SIZE * N_TILES_PER_SUBGROUP); + uint activeThreadCount = WaveActiveCountBits(true); + uint numIters = ((TILE_SIZE * N_TILES_PER_SUBGROUP) + activeThreadCount - 1) / activeThreadCount; + uint offsetPerThread = activeThreadCount; + + for (int i = 0; i < numIters; i++) + { + uint index = i * offsetPerThread + laneId; + if (index >= TILE_SIZE * N_TILES_PER_SUBGROUP) + break; + + shmPtr[offset + index] = uint4(0x3C003C00); + } + + GroupMemoryBarrierWithWaveSync(); +} + +void testLoadShB(uint tid, uint tileIndex) +{ + half inputVector[InputSize]; + float value = tid * 32; + for (int i = 0; i < 32; i++) + { + inputVector[i] = half(value + i); + } + + SPtr sharedMemoryB = __getAddress(s_sharedMemoryB[0]); + + uint subgroupIndex = tid / 32; + MMAHelper.loadVectorToShB(sharedMemoryB, tileIndex, subgroupIndex, inputVector); + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint laneId, uint subgroupIndex, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one row/column of the shared memory. + // So each thread will check 2 uint4 elements (32 bytes/16 half) in the shared memory. + // !!! IMPORTANT: the expected value is a float, not a half, because when the number is bigger than 2048, + // the half type + 1.0h will always 2048. But by using float, half(2048 + 1.0f) will be 2048.0h, but + // float(2048 + 2.0f) will be 2050.0f. This is how to construct the input vector in testLoadShB function. + // So using the same method to verify the output. 
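The float-versus-half remark above is easy to confirm on the host. The NumPy snippet below is illustrative only (it is not part of the test); it shows that above 2048 the half format's spacing is 2, so adding 1.0h no longer changes the value, while accumulating in float and converting to half only at the comparison stays exact for the values used here:

```python
import numpy as np

# Above 2048, float16 values are spaced 2 apart, so adding 1 rounds back down.
h = np.float16(2048)
print(h + np.float16(1))      # 2048.0 (the +1.0h is lost)
print(np.float16(2049.0))     # 2048.0 (2049 is not representable in half)

# Keeping the running "expected" value in float and only converting to half at the
# comparison avoids accumulating this rounding: 2048 + 2 converts to an exact 2050.
f = np.float32(2048) + np.float32(2)
print(np.float16(f))          # 2050.0
```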
+ float expected = (subgroupIndex * 1024) + laneId * 32 + tileIndex * 16; + bool res = true; + uint index = laneId * 16; // 16 half per thread + uint alignedN = ((OutputSize + 16 - 1) / 16) * 16; + + // Because of our setting, the laneId will never be larger than alignedN + { + for (int i = 0; i < 2; i++) + { + uint indexInTile = laneId * 2 + i; + uint4 values = sharedMem[indexInTile]; + uint4 element = values; + + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + half actual[2] = { bit_cast(a), bit_cast(b) }; + half expectedValues[2] = { half(expected), half(expected + 1.0f) }; + for (int verifyIndex = 0; verifyIndex < 2; verifyIndex++) + { + if (actual[verifyIndex] != expectedValues[verifyIndex]) + { + // printf("subgroup %d, actual: %.1f, expected: %.1f\n", subgroupIndex, float(actual[verifyIndex]), float(expectedValues[verifyIndex])); + return false; + } + } + expected += 2.0h; + } + } + } + + // Next we need to check if paddings are correct. + { + uint activeThreadCount = WaveActiveCountBits(true); + uint startColumn = activeThreadCount; + uint remainningColumns = alignedN - startColumn; + // Let each thread check one column. + uint columnsPerThread = (remainningColumns + activeThreadCount - 1) / activeThreadCount; + for (uint i = 0; i < columnsPerThread; i++) + { + uint columnIndex = startColumn + laneId; + if (columnIndex >= alignedN) + break; + + uint index = columnIndex * 2; + if (!sharedMem[index].equals(uint4(0)) || !sharedMem[index+1].equals(uint4(0))) + { + return false; + } + } + } + + // Next we need to check if there is any accidental load out of the range. + if (alignedN < 32) + { + uint activeThreadCount = WaveActiveCountBits(true); + uint startColumn = alignedN; + uint remainningColumns = 32 - startColumn; + // Let each thread check one column. + uint columnsPerThread = (remainningColumns + activeThreadCount - 1) / activeThreadCount; + for (uint i = 0; i < columnsPerThread; i++) + { + uint columnIndex = startColumn + laneId; + if (columnIndex >= 32) + break; + uint index = columnIndex * 2; + if (!sharedMem[index].equals(uint4(0x3C003C00)) || !sharedMem[index+1].equals(uint4(0x3C003C00))) + { + return false; + } + } + } + + return true; +} + +groupshared bool s_verifiedOutput[THREAD_COUNT/32]; +void Test(uint tid, uint tileIndex, int resIndex) +{ + // Waiting for all threads to finish the previous test as this is multi-warps test. + // So need to sync whole workgroup. + GroupMemoryBarrierWithGroupSync(); + invalidateSharedMemory(tid, __getAddress(s_sharedMemoryB[0])); + + // Each warp will load its own tile. + __target_switch + { + case cuda: + testLoadShB(tid, tileIndex); + break; + case spirv: + testLoadShB(tid, tileIndex); + break; + } + // serialRead(tid, __getAddress(s_sharedMemoryB[0])); + + uint subgroupIndex = tid / 32; + int laneId = WaveGetLaneIndex(); + uint offset = subgroupIndex * (TILE_SIZE * N_TILES_PER_SUBGROUP); + bool res = verifiedOutput(laneId, subgroupIndex, __getAddress(s_sharedMemoryB[offset]), tileIndex); + res = WaveActiveAllTrue(res); + + // Write verified result to the shared memory. + if (laneId == 0) + s_verifiedOutput[subgroupIndex] = res; + + // Wait for all threads to finish writing. + GroupMemoryBarrierWithGroupSync(); + + // Read verified result from the shared memory. 
+ if (tid == 0) + { + for (int i = 0; i < THREAD_COUNT/32; i++) + { + if (!s_verifiedOutput[i]) + { + outputBuffer[resIndex] = 0; + return; + } + } + outputBuffer[resIndex] = 1; + } +} + +// This function is just used for debugging, not for verification. So keep it here. +void serialRead(uint tid, SPtr sharedMem) +{ + GroupMemoryBarrierWithGroupSync(); + + if (tid > 0) + return; + + uint columnsPerTile = 16; + for (int id = 0; id < SUBGROUP_COUNT * N_TILES_PER_SUBGROUP * columnsPerTile; id++) + { + printf("col: %d\n", id); + for (int i = 0; i < 2; i++) + { + uint4 values = sharedMem[id * 2 + i]; + uint4 element = values; + for (int j = 0; j < 4; j++) + { + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + + half aa = bit_cast(a); + half bb = bit_cast(b); + printf("%.1f %.1f ", float(aa), float(bb)); + } + } + printf("\n"); + } +} + +RWStructuredBuffer outputBuffer12; + +[numthreads(THREAD_COUNT, 1, 1)] +[shader("compute")] +void computeMain(uint tid : SV_DispatchThreadID) +{ + Test(tid, 0, 0); // arbitrary case, N = THREAD_COUNT, read tile 0 + // BUFFER: 1 + + Test(tid, 1, 1); // arbitrary case, N = THREAD_COUNT, read tile 1 + // BUFFER-NEXT: 1 +} diff --git a/tests/neural/tiled-mma-load-test-unAligned-width-height-A.slang b/tests/neural/tiled-mma-load-test-unAligned-width-height-A.slang new file mode 100644 index 0000000000..8517bfc1c8 --- /dev/null +++ b/tests/neural/tiled-mma-load-test-unAligned-width-height-A.slang @@ -0,0 +1,248 @@ +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly +// TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature + +import neural; +#pragma warning(disable: 41017) + +// This test verifies that the tiled MMA load operations work correctly when the width, height of the matrix +// is not aligned with the tile width, height. The result of the tile will always be aligned with 16 each row +// and column with 0 paddings. + +// +// ============================================================================ +// MATRIX A - Loaded from Global Memory +// ============================================================================ +// We construct a 32x32 matrix in row-major order: +// +// Column: 0 1 2 3 ... 15 16 17 ... 31 +// Row 0: 0 1 2 3 ... 15 16 17 ... 31 +// Row 1: 32 33 34 35 ... 47 48 49 ... 63 +// ... +// Row 7: 224 225 226 227 ... 239 240 241 ... 255 +// Row 8: 256 257 258 259 ... 269 270 271 ... 287 +// ... +// Row 14: 448 449 450 451 ... 459 460 461 ... 479 +// Row 15: 480 481 482 483 ... 491 492 493 ... 511 +// ... +// Row 31: 992 993 994 995 ... 1007 1008 1009 ... 1023 +// + +// We will test different widths, heights of the matrix, for example, +// if the width == 19, and height == 3, and if we read the 2nd tile, the result will be: +// +// 16 17 18 ... 0 [13 zeros padding] +// 35 36 37 ... 0 [13 zeros padding] +// 54 55 56 ... 0 [13 zeros padding] +// ... +// [13 zeros padding rows] + +// Explanation: +// Think about the input matrix is a 1-D buffer from 0-1023, since the width of the matrix is 19 now, +// the stride of the matrix is 19. 
And height is 3, so the matrix will be: +// So the matrix will be: +// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 +// 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 +// 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 + +// Because we are reading the 2nd tile, so it will start from the 17th column (index 16), so there +// will be only 3 elements left in each row, the remaining 13 elements will be padded with 0: +// +// 16 17 18 ... 0 [13 zeros padding] +// 35 36 37 ... 0 [13 zeros padding] +// 54 55 56 ... 0 [13 zeros padding] +// ... +// [13 zeros padding rows] + + +// TEST_INPUT:ubuffer(stride=2, count=1024):name=inputBuffer +RWStructuredBuffer inputBuffer; + +// TEST_INPUT:ubuffer(stride=4, count=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +void initWeightMatrix(uint tid) +{ + inputBuffer[tid] = half(tid); +} + +void invalidateSharedMemoryA(uint tid) +{ + if (tid >= SubgroupSize) + return; + + // Initialize the shared memory with all 1.0h. + uint activeThreadCount = WaveActiveCountBits(true); + uint numIters = (32 * 2 + activeThreadCount - 1) / activeThreadCount; + for (int i = 0; i < numIters; i++) + { + uint index = tid * numIters + i; + if (index >= 32 * 2) + break; + + s_sharedMemoryA[index] = uint4(0x3C003C00); + } + GroupMemoryBarrierWithWaveSync(); +} + +static const int SubgroupSize = 32; + +// Tile A is size of 32 * 32 bytes: TileHeight is at least 16. +groupshared uint4 s_sharedMemoryA[32 * 2]; + +typealias SPtr = Ptr; + +void testLoadShA(uint tid, uint tileIndex) +{ + typealias Storage = StructuredBufferStorage; + + Storage storage = Storage(inputBuffer); + + SPtr sharedMemoryA = __getAddress(s_sharedMemoryA[0]); + MMAHelper.loadShA(sharedMemoryA, tileIndex, storage, 0); + GroupMemoryBarrierWithWaveSync(); +} + +bool verifiedOutput(uint tid, SPtr sharedMem, uint tileIndex) +{ + // Verify the output is correct, each thread will verify one row/column of the shared memory. + // So each thread will check 2 uint4 elements (32 bytes/16 half) in the shared memory. + half expected = half(tid * K + tileIndex * 16); + bool res = true; + uint index = tid * 16; // 16 half per thread + + // We only need to verify alignedM rows + uint alignedM = ((M + 16 - 1) / 16) * 16; + + if (tid < alignedM) + { + for (int i = 0; i < 2; i++) + { + uint indexInTile = tid * 2 + i; + uint4 values = sharedMem[indexInTile]; + uint4 element = values; + if (indexInTile / 2 >= M) + { + // Check out-of-boundary for Y direction. + if (!values.equals(uint4(0, 0, 0, 0))) + { + return false; + } + continue; + } + + for (int j = 0; j < 4; j++) + { + uint xIndex = (i * 8 + j * 2) + tileIndex * 16; + uint value = element[j]; + uint16_t a = (uint16_t)(value & 0xFFFF); + uint16_t b = (uint16_t)((value >> 16) & 0xFFFF); + half actual[2] = { bit_cast(a), bit_cast(b) }; + half expectedValues[2] = { expected, expected + 1.0h }; + for (int verifyIndex = 0; verifyIndex < 2; verifyIndex++) + { + if (xIndex++ >= K) + { + if (actual[verifyIndex] != 0.0h) + { + return false; + } + continue; + } + if (actual[verifyIndex] != expectedValues[verifyIndex]) + { + return false; + } + } + expected += 2.0h; + } + } + } + else + { + // For out-of-range rows, we just check if the values are same as the initialized values. + // Because we also need to check if the library accidentally write some values to the out-of-range rows. 
+        for (int i = 0; i < 2; i++)
+        {
+            uint indexInTile = tid * 2 + i;
+            uint4 values = sharedMem[indexInTile];
+            if (!values.equals(uint4(0x3C003C00)))
+            {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
+void Test(uint tid, uint tileIndex, int resIndex)
+{
+    if (tid >= SubgroupSize)
+        return;
+
+    // Fill the shared memory with invalid values.
+    invalidateSharedMemoryA(tid);
+
+    __target_switch
+    {
+    case cuda:
+        testLoadShA(tid, tileIndex);
+        break;
+    case spirv:
+        testLoadShA(tid, tileIndex);
+        break;
+    }
+    // serialRead(tid, __getAddress(s_sharedMemoryA[0]));
+
+    bool res = verifiedOutput(tid, __getAddress(s_sharedMemoryA[0]), tileIndex);
+    res = WaveActiveAllTrue(res);
+
+    if (tid == 0)
+        outputBuffer[resIndex] = res ? 1 : 0;
+}
+
+// This function is just used for debugging, not for verification. So keep it here.
+void serialRead(uint tid, SPtr sharedMem)
+{
+    if (tid > 0)
+        return;
+
+    for (int id = 0; id < 32; id++)
+    {
+        printf("tid: %d\n", id);
+        for (int i = 0; i < 2; i++)
+        {
+            uint4 values = sharedMem[id * 2 + i];
+            uint4 element = values;
+            for (int j = 0; j < 4; j++)
+            {
+                uint value = element[j];
+                uint16_t a = (uint16_t)(value & 0xFFFF);
+                uint16_t b = (uint16_t)((value >> 16) & 0xFFFF);
+
+                half aa = bit_cast(a);
+                half bb = bit_cast(b);
+                printf("%.1f %.1f ", float(aa), float(bb));
+            }
+        }
+        printf("\n");
+    }
+}
+
+[numthreads(1024, 1, 1)]
+[shader("compute")]
+void computeMain(uint tid : SV_DispatchThreadID)
+{
+    initWeightMatrix(tid);
+
+    Test<3, 19>(tid, 1, 0); // M == 3, K == 19, read tile 1
+    // BUFFER: 1
+
+    Test<5, 32>(tid, 1, 1); // M == 5, K == 32, read tile 1
+    // BUFFER-NEXT: 1
+
+    Test<14, 13>(tid, 0, 2); // M == 14, K == 13, read tile 0
+    // BUFFER-NEXT: 1
+
+    Test<25, 16>(tid, 0, 3); // M == 25, K == 16, read tile 0
+    // BUFFER-NEXT: 1
+}
diff --git a/tests/skip-list-debug.txt b/tests/skip-list-debug.txt
new file mode 100644
index 0000000000..ad469cbb46
--- /dev/null
+++ b/tests/skip-list-debug.txt
@@ -0,0 +1,37 @@
+# Skip list for debug builds
+# These tests are skipped to reduce CI time in debug builds.
+# Each line is a path prefix - tests starting with these paths will be skipped.
+
+# Skip most neural tests in debug builds, except for a few basic ones:
+# - basic-inline-vector-test.slang (basic inline vector functionality)
+# - basic-coopmat-vector-test.slang (basic cooperative matrix vector)
+# - mma-helper-test-single-warp.slang (basic MMA single warp)
+# - neural-module-discovery-diagnose.slang (module discovery diagnostics)
+
+# TODO: These tests expose the issue that our compiler is very slow in debug mode
+# when specializing complicated generics. We might need to optimize this
+# in the future. Issue #9755 tracks this.
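For reference, the entries below are consumed by the new `-skip-list` option roughly as follows (see the options.cpp and slang-test-main.cpp hunks later in this patch): text after `#` is stripped, blank lines are ignored, and the remaining lines are treated as path prefixes. The Python model below is illustrative only and omits the path simplification the real code performs.

```python
# Rough model of the -skip-list handling: strip '#' comments, trim whitespace,
# collect the remaining lines as path prefixes, and skip any test whose path
# starts with one of them. (The real code also canonicalizes each path.)
def load_skip_list(text):
    prefixes = []
    for line in text.splitlines():
        entry = line.split('#', 1)[0].strip()
        if entry:
            prefixes.append(entry)
    return prefixes

def should_skip(test_path, prefixes):
    return any(test_path.startswith(p) for p in prefixes)

skips = load_skip_list(open("tests/skip-list-debug.txt").read())
print(should_skip("tests/neural/mma-helper-test-multi-warps.slang", skips))  # True
```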
+ +tests/neural/bias-sum-reduce.slang +tests/neural/common.slang +tests/neural/mma-helper-test-multi-warps-arbitrary-size.slang +tests/neural/mma-helper-test-multi-warps.slang +tests/neural/mma-helper-test-single-warp-arbitrary-size.slang +tests/neural/mma-helper-test-single-warp.slang +tests/neural/mma-helper-test-transpose-multi-warps-arbitrary-size.slang +tests/neural/mma-helper-test-transpose-multi-warps.slang +tests/neural/mma-helper-test-transpose-single-warp-arbitrary-size.slang +tests/neural/mma-helper-test-transpose-single-warp.slang +tests/neural/neural-module-discovery-diagnose.slang +tests/neural/outerproduct-accumulate-test-arbitrary-size.slang +tests/neural/outerproduct-accumulate-test.slang +tests/neural/shared-memory-size.slang +tests/neural/test1.slang +tests/neural/tiled-mma-load-test-aligned.slang +tests/neural/tiled-mma-load-test-transpose-aligned.slang +tests/neural/tiled-mma-load-test-transpose-unAligned-width-A.slang +tests/neural/tiled-mma-load-test-transpose-unAligned-width-height-A.slang +tests/neural/tiled-mma-load-test-unAligned-height-A.slang +tests/neural/tiled-mma-load-test-unAligned-height-B.slang +tests/neural/tiled-mma-load-test-unAligned-width-B-multi-warps.slang +tests/neural/tiled-mma-load-test-unAligned-width-height-A.slang diff --git a/tools/slang-test/README.md b/tools/slang-test/README.md index 7e9d01a082..69cba971d2 100644 --- a/tools/slang-test/README.md +++ b/tools/slang-test/README.md @@ -80,6 +80,8 @@ Available APIs: - `-generate-hlsl-baselines`: Generate HLSL test baselines - `-emit-spirv-via-glsl`: Emit SPIR-V through GLSL instead of directly - `-expected-failure-list `: Specify file containing expected failures +- `-skip-list `: Specify file containing tests to skip (path prefixes) +- `-exclude-prefix `: Exclude tests with specified path prefix ## Test Types diff --git a/tools/slang-test/options.cpp b/tools/slang-test/options.cpp index 051e9bcd23..ee8113ce34 100644 --- a/tools/slang-test/options.cpp +++ b/tools/slang-test/options.cpp @@ -86,6 +86,7 @@ static bool _isSubCommand(const char* arg) " -skip-reference-image-generation Skip generating reference images for render tests\n" " -emit-spirv-via-glsl Emit SPIR-V through GLSL instead of directly\n" " -expected-failure-list Specify file containing expected failures\n" + " -skip-list Specify file containing tests to skip (path prefixes)\n" " -use-shared-library Run tests in-process using shared library\n" " -use-test-server Run tests using test server\n" " -use-fully-isolated-test-server Run each test in isolated server\n" @@ -498,6 +499,39 @@ static bool _isSubCommand(const char* arg) } } } + else if (strcmp(arg, "-skip-list") == 0) + { + if (argCursor == argEnd) + { + stdError.print("error: expected operand for '%s'\n", arg); + showHelp(stdError); + return SLANG_FAIL; + } + auto fileName = *argCursor++; + String text; + File::readAllText(fileName, text); + List lines; + StringUtil::split(text.getUnownedSlice(), '\n', lines); + for (auto line : lines) + { + // Remove comments (everything after '#' character) + auto trimmedLine = line; + auto commentIndex = line.indexOf('#'); + if (commentIndex != -1) + { + trimmedLine = line.head(commentIndex); + } + + // Trim whitespace and skip empty lines + trimmedLine = trimmedLine.trim(); + if (trimmedLine.getLength() > 0) + { + Slang::StringBuilder sb; + Slang::Path::simplify(trimmedLine, Slang::Path::SimplifyStyle::NoRoot, sb); + optionsOut->skipList.add(sb); + } + } + } else if (strcmp(arg, "-test-dir") == 0) { if (argCursor == argEnd) diff --git 
a/tools/slang-test/options.h b/tools/slang-test/options.h index e7ffe0f971..4fa2403744 100644 --- a/tools/slang-test/options.h +++ b/tools/slang-test/options.h @@ -141,6 +141,7 @@ struct Options Slang::HashSet capabilities; Slang::HashSet expectedFailureList; + Slang::List skipList; // Ignore abort message dialog popup on Windows bool ignoreAbortMsg = false; diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index e3e28ada29..9748ee8a35 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -4839,6 +4839,22 @@ static bool shouldRunTest(TestContext* context, String filePath) } } + // Check skip list - if any entry matches, skip the test + for (auto& skipEntry : context->options.skipList) + { + if (filePath.startsWith(skipEntry)) + { + if (context->options.verbosity == VerbosityLevel::Verbose) + { + context->getTestReporter()->messageFormat( + TestMessageType::Info, + "%s file is skipped because it is found in the skip list\n", + filePath.getBuffer()); + } + return false; + } + } + if (!context->options.testPrefixes.getCount()) { return true;