Concise guidance for AI coding agents working in this repo. Keep answers specific to MPLang and its current code, not generic.
- MPLang is a Python SPMD framework to orchestrate multi-party/multi-device programs from a single controller.
- Core idea: trace Python functions into an IR, then interpret/execute across parties/devices with explicit security domains.
- Core SPMD/IR: `mplang/core/primitive.py` (defines `@primitive`/`function`, control flow like `cond`, `while_loop`, `peval`), `tracer.py` (`TraceContext`/`TraceVar`), `interp.py` (`InterpContext`/`InterpVar`, `apply`), `mpobject.py`, `mptype.py`, `dtype.py`, `tensor.py`, `table.py`, `pfunc.py`, `mask.py`.
- Expression AST: `mplang/core/expr/` (Expr nodes used by tracing; see imports in `primitive.py`).
- Runtime: `mplang/runtime/` (`simulation.py` for local multi-threaded runs, `driver.py`, `server.py`, `client.py`, `communicator.py`).
- Frontends/Backends: `mplang/ops/*` (basic, jax_cc, etc.) and `mplang/kernels/*` (std, spu, phe, sql_duckdb, stablehlo).
- Devices API: `mplang/device.py` (device placement/transforms).
- Low-level party API: `mplang/simp/*` (MPI-style ops, random, smpc), used when you need rank-level control.
- Public API surface: `mplang/__init__.py` (re-exports: `function`, `compile`, `evaluate`, `fetch`, `Simulator`, etc.).
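
The core idea stated at the top (trace a Python function into an IR, then interpret it) can be illustrated with a toy sketch. This is not MPLang code: `Node`, `Tracer`, `trace`, and `interpret` are hypothetical names invented for this example, and the real Expr nodes under `mplang/core/expr/` are richer (parties, masks, types).

```python
from dataclasses import dataclass


# Hypothetical toy IR node; NOT one of MPLang's Expr classes.
@dataclass(frozen=True)
class Node:
    op: str               # "input", "const", "add", or "mul"
    args: tuple = ()      # child nodes
    value: object = None  # input name for "input", literal for "const"


class Tracer:
    """Symbolic value that records operations instead of computing them."""

    def __init__(self, node: Node):
        self.node = node

    @staticmethod
    def _lift(other) -> Node:
        # Plain Python numbers become constant nodes in the graph.
        return other.node if isinstance(other, Tracer) else Node("const", value=other)

    def __add__(self, other):
        return Tracer(Node("add", (self.node, self._lift(other))))

    def __mul__(self, other):
        return Tracer(Node("mul", (self.node, self._lift(other))))


def trace(fn, *names: str) -> Node:
    """Run fn once on symbolic inputs and return the recorded graph."""
    return fn(*(Tracer(Node("input", value=n)) for n in names)).node


def interpret(node: Node, env: dict):
    """Evaluate a recorded graph against concrete input values."""
    if node.op == "input":
        return env[node.value]
    if node.op == "const":
        return node.value
    lhs, rhs = (interpret(arg, env) for arg in node.args)
    return lhs + rhs if node.op == "add" else lhs * rhs


def f(x, y):
    return x * y + 1


ir = trace(f, "x", "y")                   # build the graph once
result = interpret(ir, {"x": 3, "y": 4})  # evaluate it later with real data
```

The same function body runs unmodified in both phases, which is the property the context rules in this document are meant to preserve.
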
- Contexts are king: code must work under both `TraceContext` and `InterpContext`.
  - Inside `@primitive` functions, use `cur_ctx()` and do NOT cross-share `TraceVar`/`InterpVar`; capture into the current context (see `_switch_ctx` in `primitive.py`).
- New primitives: decorate with `@primitive` (or the `function` alias). For backend-evaluable ops, route through `mplang/ops/basic.py` and call `peval(pfunc, eval_args, rmask?)` to create IR; return the unflattened tree.
- Masks: `Mask` models which parties hold/execute values. `set_mask` enforces the runtime execution mask; static/dynamic pmask rules are documented in its docstring.
- Control flow: use `cond`, `while_loop`, `peval`, and `ConvExpr`/`ShflExpr` via helpers in `primitive.py` instead of ad-hoc Python control flow inside traced code.
- Devices: prefer `mplang.device` for placement and `@mplang.function` for graph capture; avoid leaking ranks into device-level code (rank-level ops live in `mplang/simp`).
- Types/Style: Python 3.11+, type hints everywhere. Use `ruff` for lint/format and `mypy` on `mplang/` before PRs.
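
To make the mask bullet concrete, here is a minimal sketch of a party mask as a bitset, where bit `i` is set when party `i` holds or executes a value. `PartyMask` and its methods are hypothetical names for illustration only; the real `Mask` lives in `mplang/core/mask.py`, and the intersection rule shown here is an assumption, not a restatement of the pmask rules in the `set_mask` docstring.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass(frozen=True)
class PartyMask:
    """Hypothetical bitset mask: bit i set means party i participates."""

    bits: int

    @classmethod
    def from_ranks(cls, ranks: Iterable[int]) -> "PartyMask":
        bits = 0
        for rank in ranks:
            bits |= 1 << rank
        return cls(bits)

    def ranks(self) -> list[int]:
        # List the party indices whose bit is set, in ascending order.
        return [i for i in range(self.bits.bit_length()) if (self.bits >> i) & 1]

    def __contains__(self, rank: int) -> bool:
        return bool((self.bits >> rank) & 1)

    def __and__(self, other: "PartyMask") -> "PartyMask":
        # Parties present in both masks.
        return PartyMask(self.bits & other.bits)

    def is_subset(self, other: "PartyMask") -> bool:
        # True when every party in self is also in other.
        return (self.bits & ~other.bits) == 0


holders_a = PartyMask.from_ranks([0, 1])
holders_b = PartyMask.from_ranks([1, 2])
# Assumption for this sketch: a binary op can only run where both inputs exist.
can_run = holders_a & holders_b
```

Passing a mask like this explicitly to each operation, rather than reading it from global state, matches the guidance elsewhere in this document to keep masks visible at the call site.
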
- Install: `uv sync --group dev` then `uv pip install -e .`
- Tests: `uv run pytest` (parallel: `-n auto`); focused: e.g. `tests/core/test_primitive.py`.
- Lint/Format: `uv run ruff check . --fix` and `uv run ruff format .`
- Types: `uv run mypy mplang/`
- Tutorials: `uv run tutorials/0_basic.py`, `1_condition.py`, `2_whileloop.py`, etc.
- Protos live under `protos/` (root) and generate into the repo (e.g., `mplang/protos/v1alpha1/`).
- Workflow: `buf format -w`, `buf lint`, `buf generate`, `buf breaking --against '.git#branch=main'`, `buf dep update` when deps change. Commit generated files with the proto changes.
- How primitives build IR: `mplang/core/primitive.py` (see `peval`, `prand`, `constant`, `cond`).
- Context behavior: `mplang/core/tracer.py`, `mplang/core/interp.py`, and `_switch_ctx` in `primitive.py`.
- Runtime simulation and E2E: `mplang/runtime/simulation.py`, `tutorials/*.py`, tests in `tests/**` (e.g., `tests/kernels/test_basic.py`).
- Do add new high-level APIs by composing existing primitives instead of bypassing trace/IR.
- Do keep device-level code rank-agnostic; if you need ranks, implement in `mplang/simp`.
- Don't move `TraceVar`/`InterpVar` across contexts directly; always capture/switch via current context.
- Don't introduce hidden global state for masks/devices; pass masks explicitly where supported (e.g., `peval(..., rmask=...)`).
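
The rule about not moving variables across contexts can be sketched with a toy model. `ToyContext`, `ToyVar`, and `add` are hypothetical stand-ins, not MPLang classes, and the real capture logic (see `_switch_ctx` in `primitive.py`) may work differently; the sketch only shows the shape of "re-own a value in the current context before using it".

```python
from dataclasses import dataclass


class ToyContext:
    """Hypothetical context that owns the variables created inside it."""

    def __init__(self, name: str):
        self.name = name

    def capture(self, var: "ToyVar") -> "ToyVar":
        # Variables already owned by this context pass through unchanged;
        # foreign ones are re-wrapped so that this context owns the result.
        if var.ctx is self:
            return var
        return ToyVar(self, var.payload)


@dataclass(frozen=True)
class ToyVar:
    ctx: ToyContext
    payload: int


def add(ctx: ToyContext, a: ToyVar, b: ToyVar) -> ToyVar:
    # Capture both operands first, so the op never mixes owners.
    a, b = ctx.capture(a), ctx.capture(b)
    return ToyVar(ctx, a.payload + b.payload)


outer = ToyContext("outer")
inner = ToyContext("inner")
x = ToyVar(outer, 2)   # created in the outer context
y = ToyVar(inner, 5)   # created in the inner context
z = add(inner, x, y)   # x is captured into inner before use
```
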
Location: mplang2/. Status: ~95% feature parity with v1 (2025-12-01).
- Dialect extensibility: v1 treats `simp` as the first-class dialect; other ops (`tensor`, `table`, `spu`, etc.) are second-class: they must be embedded in `@primitive` and can't be traced or extended independently. v2 makes all dialects equal citizens.
- Type system: v1 only has `TensorType`/`TableType`; other types (ciphertexts, keys) are "simulated" via Tensor, losing compile-time checks. v2 has extensible `ScalarType`, `TensorType`, `TableType`, `VectorType`, `SSType`, `CustomType`.
Neither v1 nor v2 focuses on optimization passes yet; v2's goals are typing and dialect extensibility.
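
The type-system difference above can be sketched as follows. This is a toy illustration, not the `mplang2/edsl/typing.py` API: `ToyTensorType`, `ToyCiphertextType`, and `check_decrypt` are hypothetical names. It shows why a ciphertext disguised as a byte tensor cannot be checked before execution, while a dedicated type can.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTensorType:
    """Hypothetical stand-in for a tensor type (dtype plus shape)."""

    dtype: str
    shape: tuple


@dataclass(frozen=True)
class ToyCiphertextType:
    """Hypothetical dedicated type: records the scheme and the plaintext type."""

    scheme: str
    plaintext: ToyTensorType


def check_decrypt(arg_type) -> ToyTensorType:
    """Type rule for a decrypt op: accept ciphertexts only, return the plaintext type."""
    if not isinstance(arg_type, ToyCiphertextType):
        raise TypeError(f"decrypt expects a ciphertext, got {arg_type!r}")
    return arg_type.plaintext


# v1-style simulation: a ciphertext stored as raw bytes is indistinguishable
# from any other uint8 tensor, so a type rule has nothing to check against.
simulated = ToyTensorType("uint8", (256,))

# v2-style dedicated type: the rule can reject misuse before anything runs.
ct = ToyCiphertextType("phe", ToyTensorType("float32", (4,)))
plain_type = check_decrypt(ct)
```

Calling `check_decrypt(simulated)` raises `TypeError`, which is the kind of early check that is lost when every value is modeled as a tensor.
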
- `mplang2/edsl/`: `typing.py` (unified types), `graph.py` (Op List + SSA IR), `primitive.py`, `tracer.py`, `interpreter.py`, `jit.py`.
- `mplang2/dialects/`: simp, tensor, table, spu, tee, bfv, phe, dtypes.
- `mplang2/backends/`: simp_simulator, simp_http_driver, simp_http_worker, cli, `*_impl`.
- `mplang2/libs/device.py`: Device API (v1-compatible).
- Import: `import mplang2 as mp`
- JAX on PPU: `mp.device("P0").jax(fn)` (explicit frontend via the `.jax` property).
- Constants: `mp.put("P0", 42)` instead of `mp.device("P0")(lambda: 42)()`.
- IR inspection: `mp.jit(fn).compiler_ir()`.
✅ device/put, Simulator/Driver, evaluate/fetch, jit/trace, TensorType/TableType/ScalarType, dialects (simp/tensor/table/spu/tee/bfv/phe), table I/O, CLI. ❌ analysis/diagrams not yet ported.
Before significant changes, write design/<feature>.md with: Summary, Motivation, API Surface, Implementation (modules, TraceContext vs InterpContext interaction), Alternatives, Migration/Compat, Test Plan. See design/architecture.md as example.
Questions or unclear areas? Pin them to specific files/APIs and I'll refine this doc.