Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/metax_gpu/cinn/cinn_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr,
int shm,
void* stream);

// Launches a cooperative kernel function (grid-level sync).
// Parameter roles, as used by the definition in runtime/cinn_runtime.cc:
//   func_ptr   - kernel handle; cast to CUfunction for the launch.
//   args       - array of kernel-parameter pointers, forwarded unchanged.
//   gx, gy, gz - grid dimensions, in blocks.
//   bx, by, bz - block dimensions, in threads.
//   shm        - dynamic shared memory per block, in bytes.
//   stream     - stream handle; cast to CUstream for the launch.
//   dev_ptr, num_args - accepted but not read by that definition.
// Returns C_SUCCESS when the driver accepts the launch, C_FAILED otherwise.
extern C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
void* func_ptr,
void** args,
int num_args,
int gx,
int gy,
int gz,
int bx,
int by,
int bz,
int shm,
void* stream);

// --- From passes/pass_manager.cc ---
// Applies custom graph optimization passes
extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module);
Expand Down Expand Up @@ -99,6 +113,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) {
metax_cinn_impl.module_unload = MetaxModuleUnload;
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;
metax_cinn_impl.launch_cooperative_kernel = MetaxLaunchCooperativeKernel;

// 6. Register Compilation Strategy interface
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;
Expand Down
67 changes: 61 additions & 6 deletions backends/metax_gpu/cinn/compiler/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace metax {
// ============================================================
static const char* kMacaRuntimeSource = R"MACA_SOURCE(
#pragma once
#include <cooperative_groups.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

Expand Down Expand Up @@ -780,12 +781,43 @@ __device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64(
CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value);
}

// ===============================================================
// Cross-block barrier (hand-rolled stand-in for
// cooperative_groups::this_grid().sync()).
// Uses a sense-reversing barrier so it works correctly when called
// multiple times within the same kernel.
// REQUIREMENT: all participating thread blocks must be co-resident on the
// GPU; a block that is never scheduled never arrives, and the waiting
// blocks then spin forever.
// NOTE(review): the barrier state is indexed by blockIdx.x and the release
// condition is gridDim.y arrivals, so each barrier instance only covers the
// blocks that share one blockIdx.x. It is grid-wide only when gridDim.x == 1.
// NOTE(review): blockIdx.z / gridDim.z are not consulted; this looks like it
// assumes gridDim.z == 1 - confirm with callers.
// NOTE(review): the state arrays have 8192 slots, so blockIdx.x must stay
// below 8192 or the accesses go out of bounds.
// ===============================================================
// Per-blockIdx.x arrival counters (reset to 0 by the last arriving block).
__device__ unsigned int __cinn_grid_barrier_count[8192];
// Per-blockIdx.x sense flags, toggled between 0 and 1 on every release.
__device__ unsigned int __cinn_grid_barrier_flag[8192];

__device__ inline void __cinn_grid_sync() {
// Make this thread's prior global writes visible device-wide, then wait for
// the whole block so thread (0,0,0) can represent it at the barrier.
__threadfence();
__syncthreads();
// Exactly one thread per block takes part in the cross-block handshake.
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
// Atomic read (add 0) of the current sense; the barrier is released once
// the flag no longer equals this value.
unsigned int expected =
atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u);
// atomicAdd returns the old value, so +1 gives this block's arrival rank.
unsigned int arrived =
atomicAdd(&__cinn_grid_barrier_count[blockIdx.x], 1u) + 1u;
if (arrived == (unsigned int)gridDim.y) {
// Last arriver: reset the counter for the next use *before* flipping the
// flag, so released blocks re-entering the barrier see a zeroed counter.
atomicExch(&__cinn_grid_barrier_count[blockIdx.x], 0u);
__threadfence();
atomicExch(&__cinn_grid_barrier_flag[blockIdx.x], 1u - expected);
__threadfence();
} else {
// Everyone else spins on an atomic read until the sense flips.
while (atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u) ==
expected) {
}
}
}
// Hold the remaining threads of the block until thread (0,0,0) is released.
__syncthreads();
}

// Function body shared by the cinn_grid_reduce_<type> helpers. It expects the
// enclosing function to provide `mem`, `spatial_size` and `spatial_index`.
//
// The grid-wide barrier must come first: `mem` holds one partial result per
// blockIdx.y, and reading it before every block has passed the barrier would
// combine stale partials. Because cooperative_groups::this_grid().sync() is
// used, the kernel has to be started through a cooperative launch.
//
// After the barrier, the gridDim.y partials for `spatial_index` (stored
// `spatial_size` elements apart) are folded with cinn_<REDUCE_TYPE>, starting
// from `init_value`, and the result is returned.
//
// Only one declaration of `tmp_val` and one accumulation loop may exist here;
// a second copy placed ahead of the barrier would redeclare `tmp_val` in the
// same scope and read `mem` too early.
#define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE)               \
  cooperative_groups::this_grid().sync();                                   \
  DTYPE tmp_val = init_value;                                               \
  for (int y = 0; y < gridDim.y; y++) {                                     \
    tmp_val =                                                               \
        cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \
  }                                                                         \
  return tmp_val;

#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \
Expand All @@ -799,7 +831,26 @@ EXPAND_REDUCE_INT64_MACRO(CINN_GRID_REDUCE_MACRO)
// Expand CINN_GRID_REDUCE_MACRO once per element-type family (the EXPAND_*
// helpers are defined elsewhere; presumably one grid-reduce function is
// generated per reduce op and dtype - confirm against their definitions).
EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO)
EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO)
EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO)
// NOTE(review): if this expansion emits cinn_grid_reduce_{sum,prod,max,min}_fp16,
// it would clash with the dedicated FP32-accumulating FP16 definitions that
// follow in this source - confirm this line is meant to be dropped.
EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO)
// FP16 grid reduce with FP32 inter-block accumulation.
//
// Every block contributes one FP16 partial result per spatial index. Folding
// gridDim.y of those partials directly in FP16 would add rounding error at
// each step, growing with the number of blocks and the partials' magnitude.
// The generated functions therefore widen each partial to float, combine the
// floats with FP32_FUNC, and round back to FP16 exactly once on return.
//
// Generated function: cinn_grid_reduce_<FP16_TYPE>(mem, spatial_size,
// spatial_index). It first waits on the cooperative-groups grid barrier, so it
// must run inside a cooperatively launched kernel. INIT_VAL is the identity
// of the reduction; +/-65504 is the largest finite FP16 magnitude.
#define CINN_GRID_REDUCE_FP16_MACRO(FP16_TYPE, FP32_FUNC, INIT_VAL)        \
  __device__ inline float16 cinn_grid_reduce_##FP16_TYPE(                  \
      const float16 *mem, int spatial_size, int spatial_index) {           \
    cooperative_groups::this_grid().sync();                                \
    float acc = static_cast<float>(INIT_VAL);                              \
    for (int row = 0; row < gridDim.y; ++row) {                            \
      const float partial =                                                \
          __half2float(mem[row * spatial_size + spatial_index]);           \
      acc = FP32_FUNC(acc, partial);                                       \
    }                                                                      \
    return __float2half(acc);                                              \
  }
CINN_GRID_REDUCE_FP16_MACRO(sum_fp16, cinn_sum_fp32, 0.0f)
CINN_GRID_REDUCE_FP16_MACRO(prod_fp16, cinn_prod_fp32, 1.0f)
CINN_GRID_REDUCE_FP16_MACRO(max_fp16, cinn_max_fp32, -65504.0f)
CINN_GRID_REDUCE_FP16_MACRO(min_fp16, cinn_min_fp32, 65504.0f)

__device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) {
__shared__ bool done;
Expand Down Expand Up @@ -1238,6 +1289,10 @@ C_Status MetaxCompile(void* dev_ptr,
src_file << code;
src_file.close();
}
// std::cout << "[MetaX] src_file content written to: " << src_path
// << "\n--- BEGIN src_file ---\n"
// << kMacaRuntimeSource << "\n" << code
// << "\n--- END src_file ---" << std::endl;

// 2. Resolve compiler binary path
const char* maca_path_env = std::getenv("MACA_PATH");
Expand Down
70 changes: 66 additions & 4 deletions backends/metax_gpu/cinn/runtime/cinn_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ namespace metax {
// Loads a compiled device module from the file at `path` via cuModuleLoad.
//
//   dev_ptr - not used by this function.
//   path    - file-system path of the module to load; must not be null.
//   mod_out - receives the loaded CUmodule handle as an opaque pointer;
//             must not be null. Left untouched on failure.
//
// Returns C_SUCCESS on success. On any failure a diagnostic is written to
// std::cerr and C_FAILED is returned.
C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) {
  // Reject null arguments up front: streaming a null `const char*` into
  // std::cerr is undefined behaviour, and `mod_out` is dereferenced below.
  if (path == nullptr || mod_out == nullptr) {
    std::cerr << "[MetaxModuleLoad] FAILED: null path or output handle"
              << std::endl;
    return C_Status::C_FAILED;
  }

  CUmodule module;
  CUresult err = cuModuleLoad(&module, path);
  // Single failure check: an earlier silent early-return here would make the
  // diagnostic below unreachable.
  if (err != CUDA_SUCCESS) {
    std::cerr << "[MetaxModuleLoad] FAILED to load module from: " << path
              << ", error=" << err << std::endl;
    return C_Status::C_FAILED;
  }

  *mod_out = reinterpret_cast<void*>(module);
  return C_Status::C_SUCCESS;
}
Expand All @@ -47,8 +50,11 @@ C_Status MetaxGetKernelAddress(void* dev_ptr,
void** func_out) {
CUfunction func;
CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name);
if (err != CUDA_SUCCESS) return C_Status::C_FAILED;

if (err != CUDA_SUCCESS) {
std::cerr << "[MetaxGetKernelAddress] FAILED func_name=" << func_name
<< " module=" << module_handle << " error=" << err << std::endl;
return C_Status::C_FAILED;
}
*func_out = reinterpret_cast<void*>(func);
return C_Status::C_SUCCESS;
}
Expand Down Expand Up @@ -82,6 +88,62 @@ C_Status MetaxLaunchKernel(void* dev_ptr,
return C_Status::C_SUCCESS;
}

// Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to
// wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge)
// so that all thread blocks are co-resident on the GPU, which the
// cross-block barriers used by grid_reduce
// (cooperative_groups::this_grid().sync()) require.
//
//   dev_ptr    - not used by this function.
//   func_ptr   - kernel handle; cast to CUfunction.
//   args       - array of kernel-parameter pointers, forwarded unchanged.
//   num_args   - not used by this function.
//   gx, gy, gz - grid dimensions, in blocks.
//   bx, by, bz - block dimensions, in threads.
//   shm        - dynamic shared memory per block, in bytes.
//   stream     - stream handle; cast to CUstream.
//
// Returns C_SUCCESS when the driver accepts the launch. Otherwise a
// diagnostic is written to std::cerr and C_FAILED is returned.
C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
                                      void* func_ptr,
                                      void** args,
                                      int num_args,
                                      int gx,
                                      int gy,
                                      int gz,
                                      int bx,
                                      int by,
                                      int bz,
                                      int shm,
                                      void* stream) {
  CUfunction kernel = static_cast<CUfunction>(func_ptr);

  CUresult err = cuLaunchCooperativeKernel(kernel,
                                           gx,
                                           gy,
                                           gz,
                                           bx,
                                           by,
                                           bz,
                                           shm,
                                           static_cast<CUstream>(stream),
                                           args);
  if (err == CUDA_SUCCESS) {
    return C_Status::C_SUCCESS;
  }

  std::cerr << "[MetaxLaunchCooperativeKernel] FAILED error=" << err
            << " grid=(" << gx << "," << gy << "," << gz << ")"
            << " block=(" << bx << "," << by << "," << bz << ")"
            << " shm=" << shm;

  // Best-effort capacity report, done only on the failure path so that a
  // successful launch pays for no extra driver calls. Each query is checked;
  // if any of them fails the extra fields are simply omitted.
  CUdevice device = 0;
  int sm_count = 0;
  int active_blocks_per_sm = 0;
  const bool have_capacity =
      cuCtxGetDevice(&device) == CUDA_SUCCESS &&
      cuDeviceGetAttribute(&sm_count,
                           CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
                           device) == CUDA_SUCCESS &&
      cuOccupancyMaxActiveBlocksPerMultiprocessor(
          &active_blocks_per_sm, kernel, bx * by * bz,
          static_cast<size_t>(shm)) == CUDA_SUCCESS;
  if (have_capacity) {
    // Widen before multiplying so large grids do not overflow int.
    const long long total_blocks = static_cast<long long>(gx) * gy * gz;
    const long long max_active =
        static_cast<long long>(sm_count) * active_blocks_per_sm;
    std::cerr << " total_blocks=" << total_blocks
              << " sm_count=" << sm_count
              << " active_blocks_per_sm=" << active_blocks_per_sm
              << " max_active=" << max_active;
  }
  std::cerr << std::endl;
  return C_Status::C_FAILED;
}

} // namespace metax
} // namespace custom_device
} // namespace paddle
Loading