[WIP] Feat/muxi device registry #145

Open
kilinchange wants to merge 11 commits into master from feat/muxi_device_registry

Conversation

@kilinchange (Collaborator) commented Apr 20, 2026

MetaX (沐曦) adaptation; a temporary PR used to display the code changes for comparison.

Introduce a MACA (MetaX 沐曦) backend plugged into the DeviceGuardImpl /
kernel dispatcher framework, targeting the minimal kernel set needed to
validate single-card fp32 training (e.g. mnist) end-to-end:

- Build system: USE_MACA / USE_MCCL options, mxcc toolchain override,
  mxomp linkage under USE_OMP, .maca kernel library with -x maca, and
  backend-exclusive SRC filtering so non-target backends are not pulled in.
- Device enum: add Device::DeviceType::kMACA (kCount bumped to 3),
  IsMACA(), and a three-way ToString() switch.
- common/maca: MACA_CHECK / MCBLAS_CHECK / MCCL_CHECK macros and the
  kernel_helper.cuh template library (Cast/Neg/Sin/Pow/Add/Sub/Mul/Div/
  Max/Min/Fma/fastAtomicAdd) plus a cub_compat.cuh shim pinning CubSumOp/
  CubMaxOp/CubMinOp to the pre-2.8 CUB API that MACA ships.
- core/runtime/maca: MacaStream / MacaEvent / MacaBlasHandle derived from
  core::Stream / Event / BlasHandle, and MacaGuardImpl mirroring
  CudaGuardImpl (mcInit(0) in ctor, call_once'd default stream/handle
  caches, full stream/event/sync/blas/memory surface). Mempool watermark
  hooks are stubs pending SDK verification.
- datatype.h / tensor.cc / nn/init.cc: add USE_MACA branches to map
  kBFLOAT16 / kFLOAT16 to __maca_bfloat16 / __half, specialize the
  is_floating_point_ext / is_arithmetic_ext / LargerType traits, route
  Fill casts through float under real device backends to dodge the
  ambiguous __half(int) constructor on MACA, and wire Arange for bf16/fp16.
- kernels/maca: mechanically port the minimal 5-kernel slice
  (elementwise, linear, fill, no_op, accumulate_grad) from their .cu
  counterparts, switching blas/stream acquisition to the new
  GetDeviceGuardImpl()->GetBlasHandle()/GetStream() idiom.
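The device-enum addition in the second bullet can be sketched roughly as follows; the enum member names, values, and the free-function form of `IsMACA`/`ToString` are assumptions for illustration, not lifted from the repository:

```cpp
#include <cstdint>
#include <string>

// Hypothetical sketch of Device::DeviceType with kMACA added and
// kCount bumped to 3, as the PR describes.
enum class DeviceType : int8_t { kCPU = 0, kCUDA = 1, kMACA = 2, kCount = 3 };

inline bool IsMACA(DeviceType t) { return t == DeviceType::kMACA; }

// The "three-way ToString() switch".
inline std::string ToString(DeviceType t) {
    switch (t) {
    case DeviceType::kCPU:
        return "cpu";
    case DeviceType::kCUDA:
        return "cuda";
    case DeviceType::kMACA:
        return "maca";
    default:
        return "unknown";
    }
}
```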
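The status-check macros presumably follow the usual CUDA_CHECK shape. A minimal sketch, with `mcError_t` / `mcSuccess` / `mcGetErrorString` replaced by mock stand-ins so the snippet compiles without the MACA SDK:

```cpp
#include <cstdio>
#include <cstdlib>

// Mock stand-ins; the real macro would use the MACA runtime's
// mcError_t / mcSuccess / mcGetErrorString from the SDK headers.
using mcError_t = int;
constexpr mcError_t mcSuccess = 0;
inline const char *mcGetErrorString(mcError_t) { return "mock error"; }

// Typical CHECK-macro shape: evaluate the expression once, abort with
// file and line information on failure.
#define MACA_CHECK(expr)                                               \
    do {                                                               \
        mcError_t err__ = (expr);                                      \
        if (err__ != mcSuccess) {                                      \
            std::fprintf(stderr, "MACA error %s at %s:%d\n",           \
                         mcGetErrorString(err__), __FILE__, __LINE__); \
            std::abort();                                              \
        }                                                              \
    } while (0)
```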

The MCCL collective backend and the remaining 15 kernels (which are
required for gpt2 / DDP) will land in a follow-up commit.

Complete the MACA backend by adding the MCCL-based collective
implementation and the rest of the kernel library, enabling multi-card
training (DDP) and larger models such as gpt2.

- core/ccl/maca: McclComm / McclUniqueId wrappers around mcclComm_t /
  mcclUniqueId, with Size/Data/Load tied to sizeof(mcclUniqueId) so that
  the existing backend-agnostic WriteUniqueIdFile / ReadUniqueIdFile
  unique-id exchange path works unchanged. McclImpl mirrors NcclImpl
  with kMcclDtypeMap / kMcclReduceOpMap and routes every collective
  through mcStream_t via dynamic_cast<MacaStream *>. Registered via
  INFINI_TRAIN_REGISTER_CCL_IMPL(kMACA, McclImpl), so ProcessGroup
  backed by Device::DeviceType::kMACA transparently picks up MCCL
  without any ProcessGroupMCCL subclass.
- kernels/maca: mechanically port the remaining 15 kernels (cast,
  comm, concat, cross_entropy, embedding, gather, layernorm, outer,
  reduction, slice, softmax, split, stack, transform,
  vocab_parallel_cross_entropy) from their .cu counterparts, including
  the cub_compat path for cross_entropy/softmax/reduction, mcblas
  GEMM / GemmEx calls in outer, and __maca_bfloat16 / __half typing
  throughout.
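The unique-id wrapper in the first bullet can be sketched like this; `mcclUniqueId` is mocked as an opaque 128-byte struct (the real definition comes from the MCCL headers), and the Size/Data/Load trio is what keeps the byte-oriented WriteUniqueIdFile / ReadUniqueIdFile path backend-agnostic:

```cpp
#include <cstddef>
#include <cstring>

// Mock of MCCL's opaque unique-id struct; size chosen arbitrarily here.
struct mcclUniqueId {
    char internal[128];
};

// Wrapper shape described in the PR: Size/Data/Load are all tied to
// sizeof(mcclUniqueId), so the exchange path just moves raw bytes.
class McclUniqueId {
  public:
    static constexpr std::size_t Size() { return sizeof(mcclUniqueId); }
    const void *Data() const { return &id_; }
    void Load(const void *bytes) { std::memcpy(&id_, bytes, Size()); }

  private:
    mcclUniqueId id_{};
};
```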

CMakeLists.txt:
- Pre-set HAVE_MODE_T/HAVE_SSIZE_T and their sentinel variables
  (HAVE_HAVE_MODE_T/HAVE_HAVE_SSIZE_T) before add_subdirectory(glog),
  since mxcc cmake feature-detection probes cannot find standard POSIX
  headers; without the sentinels check_type_size re-runs and overwrites
  the pre-set values, causing glog to emit conflicting fallback typedefs
- Add BUILD_TESTING=OFF to skip glog unit tests (-fPIE unsupported by mxcc)
- Add BUILD_SHARED_LIBS=OFF to build glog as a static library; mxcc
  defaults to hidden symbol visibility, making libglog.so export nothing
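A sketch of what that pre-seeding could look like, assuming glog probes the types via CMake's `check_type_size` (the variable names follow the commit text and have not been verified against glog's actual CMakeLists):

```cmake
# Pre-seed the probe results and their HAVE_* sentinels so
# check_type_size inside glog does not re-run and overwrite them.
set(HAVE_MODE_T 1 CACHE INTERNAL "")
set(HAVE_HAVE_MODE_T 1 CACHE INTERNAL "")
set(HAVE_SSIZE_T 1 CACHE INTERNAL "")
set(HAVE_HAVE_SSIZE_T 1 CACHE INTERNAL "")

set(BUILD_TESTING OFF CACHE BOOL "" FORCE)      # glog tests need -fPIE, unsupported by mxcc
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE)  # mxcc hides symbols by default

add_subdirectory(glog)
```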

datatype.h:
- Add is_bfloat16<T> and is_fp16<T> type traits with USE_CUDA/USE_MACA
  specializations, needed by common_cpu.h Cast and init.cc ARANGE_CASE
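A plain-C++ sketch of the trait shape, with `Bf16` / `Fp16` standing in for `__maca_bfloat16` / `__half` (which require the vendor headers):

```cpp
#include <type_traits>

// Stand-in types; under USE_CUDA/USE_MACA the specializations would
// target the real vendor half/bfloat16 types instead.
struct Bf16 {};
struct Fp16 {};

template <typename T> struct is_bfloat16 : std::false_type {};
template <> struct is_bfloat16<Bf16> : std::true_type {};

template <typename T> struct is_fp16 : std::false_type {};
template <> struct is_fp16<Fp16> : std::true_type {};
```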

common/cpu/common_cpu.h:
- Route fp16/bf16 destinations through float in Cast<T>(), avoiding
  ambiguous integer→__half/__maca_bfloat16 conversion on MACA
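The ambiguity and the float detour can be reproduced with a stand-in type that, like the vendor half types, offers only float and double constructors (so `HalfLike(0)` from an int would be an ambiguous conversion, while going through `static_cast<float>` is exact):

```cpp
// Hypothetical stand-in for __half / __maca_bfloat16: only float and
// double constructors, mirroring the ambiguity described in the PR.
struct HalfLike {
    explicit HalfLike(float v) : value(v) {}
    explicit HalfLike(double v) : value(static_cast<float>(v)) {}
    float value;
};

// Cast routed through float for fp16/bf16 destinations, as described;
// the integer source is narrowed to float before construction.
template <typename Dst, typename Src>
Dst CastViaFloat(Src v) {
    return Dst(static_cast<float>(v));
}
```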

kernels/maca/{stack,concat,slice,transform,elementwise,split,gather}.maca:
- Add reinterpret_cast<void **> to all mcMallocAsync(&ptr, ...) calls;
  MACA's mcMallocAsync requires void** but typed pointers were passed
- Fix mcDevAttrMultiProcessorCount → mcDeviceAttributeMultiProcessorCount
  in elementwise.maca (correct MACA enum name)
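The call-site change can be illustrated with a malloc-backed stand-in that shares mcMallocAsync's `void **` signature (the real function and `mcStream_t` come from the MACA runtime; both are mocked here so the sketch runs):

```cpp
#include <cstddef>
#include <cstdlib>

// Mock of the MACA allocator signature: first parameter is void **,
// so a typed pointer needs an explicit reinterpret_cast at call sites.
using mcStream_t = void *;
inline int mcMallocAsync(void **ptr, std::size_t size, mcStream_t) {
    *ptr = std::malloc(size);
    return *ptr ? 0 : 1;
}

inline float *AllocFloats(std::size_t n, mcStream_t stream) {
    float *p = nullptr;
    // Without the cast, float ** does not implicitly convert to void **.
    mcMallocAsync(reinterpret_cast<void **>(&p), n * sizeof(float), stream);
    return p;
}
```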

optimizer.cc:
- Change Fill<T>(0) → Fill<T>(0.f) for Adam m/v initialization;
  __half(0) is ambiguous on MACA (only float/double ctors available)

nn/init.cc:
- Replace std::iota + static_cast<TYPE>(start) in ARANGE_CASE with an
  explicit loop via static_cast<float> to avoid ambiguous integer→fp16/
  bf16 conversion for kBFLOAT16/kFLOAT16 cases
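A sketch of the ARANGE_CASE rewrite, using a stand-in half type: each value goes through an explicit `static_cast<float>` instead of std::iota's integer assignment, which would hit the ambiguous constructor:

```cpp
#include <cstdint>
#include <vector>

// Stand-in for __half with only float/double constructors.
struct Fp16Like {
    explicit Fp16Like(float v) : v(v) {}
    explicit Fp16Like(double v) : v(static_cast<float>(v)) {}
    float v;
};

// Explicit loop via static_cast<float>, as the commit describes;
// the function shape is illustrative, not the repository's signature.
template <typename T>
std::vector<T> Arange(int64_t n, float start, float step) {
    std::vector<T> out;
    out.reserve(static_cast<std::size_t>(n));
    for (int64_t i = 0; i < n; ++i) {
        out.push_back(T(static_cast<float>(start + static_cast<float>(i) * step)));
    }
    return out;
}
```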

example/gpt2/main.cc:
- Add kDeviceMACA constant, update --device validator to accept "maca",
  and add Device::DeviceType::kMACA branch in device selection

Port MACA backend to master's backend-explicit dtype registration:

- Add src/core/runtime/maca/maca_dispatch.h: register __half / __maca_bfloat16
  via BackendTypeMap<kMACA, kFLOAT16/kBFLOAT16>, declare
  INFINI_REGISTER_STANDARD_BACKEND_TYPES(kMACA), and expose DispatchMacaFunc /
  MacaTypeMap mirroring the CUDA side.
- Replace every DispatchFunc<...>/WidestType_t/DataTypeMap_v site across 18
  MACA kernels with DispatchMacaFunc / PromoteDataTypes.
- Replace Tensor::Fill<T>(0) template calls with Fill(0) to match the new
  Scalar-taking Tensor::Fill API.
- fill.maca: route Scalar::to<T> through common::maca::Cast<T>(scalar.to<float>())
  for __maca_bfloat16/__half to avoid ambiguous static_cast from integer
  Scalar kinds (see scalar.h TODO).

The MACA runtime auto-cross-maps mcMalloc'd buffers as P2P-readonly
between sibling devices in the same process, so multi-thread DDP
(nthread>=4) crashed ~70% of the time during model upload with
"Writing to readonly page" on a 64MB buffer whose owner node was
missing from the mapped peer list.

llama3/main.cc: defer ProcessGroup creation until after model->To,
serialize model->To across DP threads with a process-wide mutex,
and barrier between upload and PG init so MCCL P2P registration
never overlaps with peer-thread allocations. Compute in-group
ranks via std::find on the rank topology so LoadFromLLMC still
sees the correct tp_rank before any PG exists.
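The serialize-then-barrier pattern can be sketched in portable C++ (a reusable condition-variable barrier is used here instead of C++20 std::barrier; all names are illustrative, not the llama3/main.cc symbols):

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

// Process-wide mutex serializing the "model upload" phase.
std::mutex upload_mutex;

// Reusable barrier: the last thread to arrive flips the phase and
// wakes the others, so upload and PG-init phases never overlap.
class Barrier {
  public:
    explicit Barrier(int n) : count_(n), waiting_(0), phase_(0) {}
    void Wait() {
        std::unique_lock<std::mutex> lk(m_);
        int phase = phase_;
        if (++waiting_ == count_) {
            waiting_ = 0;
            ++phase_;
            cv_.notify_all();
        } else {
            cv_.wait(lk, [&] { return phase_ != phase; });
        }
    }

  private:
    std::mutex m_;
    std::condition_variable cv_;
    int count_, waiting_, phase_;
};
```

Each DP thread would take `upload_mutex` for its upload, call `Wait()`, and only then create its ProcessGroup.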

reducer.cc: switch FinalizeBackward to host-blocking
work->Synchronize() so the CPU bucket-rebuild can't race past an
in-flight AllReduce.

maca_guard_impl.cc: setenv MACA_LAUNCH_BLOCKING=1 before mcInit(0)
in the ctor (setenv from main is too late since mcInit runs during
static init), and serialize mcMalloc/mcFree behind a global mutex.
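The ordering trick (env var set inside the constructor, before init runs) in sketch form; `mcInit` itself is elided since it needs the SDK, and the no-overwrite flag is an assumption about leaving user overrides intact:

```cpp
#include <cstdlib>

// Because mcInit(0) runs during static initialization, setenv from
// main() is too late; the guard-impl ctor sets the variable first.
struct MacaGuardImplSketch {
    MacaGuardImplSketch() {
        // overwrite=0: respect a value the user exported themselves.
        setenv("MACA_LAUNCH_BLOCKING", "1", /*overwrite=*/0);
        // ... mcInit(0) would follow here in the real code ...
    }
};
```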

llama3/gpt2 main.cc: std::_Exit(0) after training when device==maca
&& nthread_per_process>1 to bypass the broken static-destruction
chain — ProcessGroupMCCL intentionally skips mcclCommDestroy, and
the leaked MCCL/P2P buffers otherwise trip mxkwUnmapMemoryToGPU
and SIGABRT during teardown.

Validated: 20/20 passes on
  ./llama3 --device maca --nthread_per_process=8 --num_iteration=10
           --batch_size=10 --total_batch_size=5120
Single-card path (nthread_per_process=1) still passes.

- Move MACA/MCCL P2P_DISABLE setenv into MacaGuardImpl ctor and parse
  --tensor_parallel from /proc/self/cmdline, so both flags land before
  mcInit(0) (setenv from main() was too late at static init).
- Also disable MCCL_P2P_DISABLE when TP>1: MACA_P2P_DISABLE alone still
  lets MCCL establish its own P2P buffers, which deadlocks multi-PG
  init on TP+SP / TP+SP+PP+VPP.
- gpt2 main: defer ProcessGroup creation until after model->To(device),
  serialize the upload under a mutex + barrier across DP threads. MCCL
  init otherwise leaves stale read-only P2P mappings in the VA ranges
  mcMalloc later returns, racing with concurrent model uploads.
- Drop the now-redundant setenv blocks from gpt2/llama3 main().
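Parsing the flag from /proc/self/cmdline amounts to splitting NUL-separated argv strings. A sketch that operates on a caller-supplied buffer so it can be exercised without /proc (the flag name follows the commit text; the default of 1 is an assumption):

```cpp
#include <cstddef>
#include <string>
#include <vector>

// /proc/self/cmdline is a sequence of NUL-terminated argv strings;
// split them, then look for --tensor_parallel=N.
int ParseTensorParallel(const char *buf, std::size_t len) {
    std::vector<std::string> args;
    std::size_t start = 0;
    for (std::size_t i = 0; i < len; ++i) {
        if (buf[i] == '\0') {
            args.emplace_back(buf + start, i - start);
            start = i + 1;
        }
    }
    const std::string prefix = "--tensor_parallel=";
    for (const auto &a : args) {
        if (a.compare(0, prefix.size(), prefix) == 0) {
            return std::stoi(a.substr(prefix.size()));
        }
    }
    return 1;  // assumed default: no tensor parallelism
}
```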
kilinchange force-pushed the feat/muxi_device_registry branch from 321a98b to 1f10a97 on April 22, 2026 at 10:00.