diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc index a01bd0e67e..332cc99674 100644 --- a/backends/metax_gpu/cinn/cinn_interface.cc +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -67,6 +67,20 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr, int shm, void* stream); +// Launches a cooperative kernel function (grid-level sync) +extern C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream); + // --- From passes/pass_manager.cc --- // Applies custom graph optimization passes extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module); @@ -99,6 +113,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) { metax_cinn_impl.module_unload = MetaxModuleUnload; metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; metax_cinn_impl.launch_kernel = MetaxLaunchKernel; + metax_cinn_impl.launch_cooperative_kernel = MetaxLaunchCooperativeKernel; // 6. Register Compilation Strategy interface metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index b65f73e6e4..4a4cfc5571 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -39,6 +39,7 @@ namespace metax { // ============================================================ static const char* kMacaRuntimeSource = R"MACA_SOURCE( #pragma once +#include #include #include @@ -780,12 +781,43 @@ __device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64( CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value); } +// =============================================================== +// Grid-wide Barrier (emulates cooperative_groups::this_grid().sync()) +// Uses a sense-reversing barrier so it works correctly when called +// multiple times within the same kernel. +// REQUIREMENT: all thread blocks must be co-resident on the GPU. +// =============================================================== +__device__ unsigned int __cinn_grid_barrier_count[8192]; +__device__ unsigned int __cinn_grid_barrier_flag[8192]; + +__device__ inline void __cinn_grid_sync() { + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + unsigned int expected = + atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u); + unsigned int arrived = + atomicAdd(&__cinn_grid_barrier_count[blockIdx.x], 1u) + 1u; + if (arrived == (unsigned int)gridDim.y) { + atomicExch(&__cinn_grid_barrier_count[blockIdx.x], 0u); + __threadfence(); + atomicExch(&__cinn_grid_barrier_flag[blockIdx.x], 1u - expected); + __threadfence(); + } else { + while (atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u) == + expected) { + } + } + } + __syncthreads(); +} + #define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \ - DTYPE tmp_val = init_value; \ - for (int y = 0; y < gridDim.y; y++) { \ - tmp_val = \ - cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ - } \ + cooperative_groups::this_grid().sync(); \ + DTYPE tmp_val = init_value; \ + for (int y = 0; y < gridDim.y; y++) { \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ + } \ return tmp_val; #define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ @@ -799,7 +831,26 @@ EXPAND_REDUCE_INT64_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO) -EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO) +// FP16 grid reduce: accumulate in FP32 to avoid precision loss when summing +// multiple FP16 block-level partial sums. Each partial sum can have magnitude +// O(block_size * input_scale), and accumulating N such values in FP16 incurs +// error proportional to N * magnitude * eps_fp16. Using FP32 for the inter- +// block accumulation step keeps the error at FP16 quantization level only. +#define CINN_GRID_REDUCE_FP16_MACRO(FP16_TYPE, FP32_FUNC, INIT_VAL) \ + __device__ inline float16 cinn_grid_reduce_##FP16_TYPE( \ + const float16 *mem, int spatial_size, int spatial_index) { \ + cooperative_groups::this_grid().sync(); \ + float tmp_val = (float)(INIT_VAL); \ + for (int y = 0; y < gridDim.y; y++) { \ + tmp_val = FP32_FUNC( \ + tmp_val, __half2float(mem[y * spatial_size + spatial_index])); \ + } \ + return __float2half(tmp_val); \ + } +CINN_GRID_REDUCE_FP16_MACRO(sum_fp16, cinn_sum_fp32, 0.0f) +CINN_GRID_REDUCE_FP16_MACRO(prod_fp16, cinn_prod_fp32, 1.0f) +CINN_GRID_REDUCE_FP16_MACRO(max_fp16, cinn_max_fp32, -65504.0f) +CINN_GRID_REDUCE_FP16_MACRO(min_fp16, cinn_min_fp32, 65504.0f) __device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) { __shared__ bool done; @@ -1238,6 +1289,10 @@ C_Status MetaxCompile(void* dev_ptr, src_file << code; src_file.close(); } + // std::cout << "[MetaX] src_file content written to: " << src_path + // << "\n--- BEGIN src_file ---\n" + // << kMacaRuntimeSource << "\n" << code + // << "\n--- END src_file ---" << std::endl; // 2. Resolve compiler binary path const char* maca_path_env = std::getenv("MACA_PATH"); diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index 7f19db35e4..5bf4d3dead 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -28,8 +28,11 @@ namespace metax { C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { CUmodule module; CUresult err = cuModuleLoad(&module, path); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxModuleLoad] FAILED to load module from: " << path + << ", error=" << err << std::endl; + return C_Status::C_FAILED; + } *mod_out = reinterpret_cast(module); return C_Status::C_SUCCESS; } @@ -47,8 +50,11 @@ C_Status MetaxGetKernelAddress(void* dev_ptr, void** func_out) { CUfunction func; CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxGetKernelAddress] FAILED func_name=" << func_name + << " module=" << module_handle << " error=" << err << std::endl; + return C_Status::C_FAILED; + } *func_out = reinterpret_cast(func); return C_Status::C_SUCCESS; } @@ -82,6 +88,62 @@ C_Status MetaxLaunchKernel(void* dev_ptr, return C_Status::C_SUCCESS; } +// Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to +// wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge) +// to guarantee all thread blocks are co-resident on the GPU, which is +// required by cross-block grid_reduce barriers (__cinn_grid_sync). +C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream) { + int device = 0; + cuCtxGetDevice(&device); + int sm_count = 0; + cuDeviceGetAttribute( + &sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device); + int active_blocks_per_sm = 0; + cuOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks_per_sm, + static_cast(func_ptr), + bx * by * bz, + static_cast(shm)); + // Comment for debug + // std::cerr << "[MetaxLaunchCooperativeKernel]" + // << " grid=(" << gx << "," << gy << "," << gz << ")" + // << " total_blocks=" << (gx * gy * gz) + // << " block_size=" << (bx * by * bz) + // << " shm=" << shm + // << " sm_count=" << sm_count + // << " active_blocks_per_sm=" << active_blocks_per_sm + // << " max_active=" << (sm_count * active_blocks_per_sm) << + // std::endl; + CUresult err = cuLaunchCooperativeKernel(static_cast(func_ptr), + gx, + gy, + gz, + bx, + by, + bz, + shm, + static_cast(stream), + args); + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxLaunchCooperativeKernel] FAILED error=" << err + << " grid=(" << gx << "," << gy << "," << gz << ")" + << " block=(" << bx << "," << by << "," << bz << ")" + << " shm=" << shm << std::endl; + return C_Status::C_FAILED; + } + return C_Status::C_SUCCESS; +} + } // namespace metax } // namespace custom_device } // namespace paddle