Adopt Local Map Wrapper for Inner Attention #2557
Conversation
```python
    v, DTensor
), "q, k, v should all be DTensors"
if self._local_map_fn is None:
    self._local_map_fn = local_map(
```
@fegin I had this local_map wrapping around the CP pre-forward and post-forward hooks, so that local_map handles the conversion of DTensor to plain tensor before the CP pre-forward hook fires. If it's the other way around, I hit this assertion error: https://github.com/pytorch/pytorch/blame/main/torch/distributed/tensor/experimental/_context_parallel/_attention.py#L1394
Not sure if this is the best way; maybe you already have a plan to rewrite the CP API for full DTensor support.
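The ordering constraint above can be sketched with a pure-Python mock (all names here are illustrative stand-ins, not the real PyTorch or torchtitan APIs): the DTensor-to-local conversion must wrap outside the CP pre-forward hook, so the hook only ever sees plain tensors and its assertion passes.

```python
# Hypothetical mock of the hook-ordering problem; FakeDTensor, cp_pre_hook,
# and local_map_convert are illustrative, not real PyTorch APIs.

class FakeDTensor:
    def __init__(self, data):
        self.data = data

    def to_local(self):
        return self.data


def cp_pre_hook(args):
    # Mirrors the assertion in _context_parallel/_attention.py:
    # CP expects plain (local) tensors by the time it runs.
    assert not any(isinstance(a, FakeDTensor) for a in args), "CP hook received a DTensor"
    return args


def local_map_convert(args):
    # Stand-in for what local_map does on entry: unwrap DTensors.
    return tuple(a.to_local() if isinstance(a, FakeDTensor) else a for a in args)


def forward(q, k, v):
    return (q, k, v)


def call_with_correct_order(*args):
    # local_map wraps *outside* the CP pre-forward hook, so the hook
    # only ever sees already-unwrapped local tensors.
    args = local_map_convert(args)
    args = cp_pre_hook(args)
    return forward(*args)


out = call_with_correct_order(FakeDTensor(1), FakeDTensor(2), FakeDTensor(3))
print(out)  # (1, 2, 3)
```

If the two wrappers were nested the other way around, `cp_pre_hook` would see `FakeDTensor` instances and trip its assertion, which is the failure mode linked above.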
Yes, this makes sense. Some comments:

- Instead of `_InnerAttentionLocalMap`, we should just implement this logic into `Module`. This is something we will need anyway for MoE. The implementation can be less generic at this point; leave some TODOs so that once we integrate the config system with the sharding spec, the implementation can be made more generic.
- We will have to ensure `torch.compile` doesn't break. So try with `--compile.enable`.
I realized that these inner wrappers are not `Module` yet. You can just wrap them with `Module`, with `Config` left empty for now.
Or can we inherit `Module` and call it `LocalMapModule`?
Compile is not working with this `super().__call__` approach. Claude tells me:

Why it breaks compile: when dynamo encounters `self.inner_attention(xq, xk, xv, ...)`, it detects that `LocalMapModule` has a custom `__call__` (not the standard `nn.Module.__call__`). Dynamo then tries to trace through the custom `__call__` as a regular Python function. Inside `__call__`, the `super().__call__(q, k, v, **kwargs)` call bypasses dynamo's special `nn.Module` call handling (which normally inlines `forward()` and properly handles hooks). This causes FX nodes to be created without proper `meta["val"]` metadata, leading to the inductor `KeyError: 'val'` error.

Not sure how six handles compile with this `__call__` override.
I think `to_local` and `from_local` on the module boundary is fine, because we can still force a pair of them within a clear boundary. Would be good to know if it works with `spmd_types`, which is a context manager.
For context managers, we can manually call `__enter__()` and `__exit__()`.
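For reference, a minimal sketch of driving a context manager manually with `__enter__`/`__exit__` instead of a `with` block, e.g. entering it in a forward pre-hook and exiting in a forward post-hook (`spmd_mode` and `trace` are illustrative names, not the real `spmd_types` API):

```python
# Manually driving a context manager across a hook boundary.
from contextlib import contextmanager

trace = []

@contextmanager
def spmd_mode():  # hypothetical stand-in for a context manager like spmd_types
    trace.append("enter")
    try:
        yield
    finally:
        trace.append("exit")

cm = spmd_mode()
cm.__enter__()                  # e.g. inside a forward pre-hook
trace.append("forward")         # the wrapped forward runs here
cm.__exit__(None, None, None)   # e.g. inside a forward post-hook

print(trace)  # ['enter', 'forward', 'exit']
```

The caveat raised below still applies: this pattern only works if tracing machinery (e.g. `__torch_function__` interception) tolerates the split enter/exit.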
@xmfan there's some limitation due to `__torch_function__`, so we can't use `__enter__` and `__exit__`. This is the current entry point: https://fburl.com/code/ypffryqp
Directly replacing `forward` with the wrapped forward works: #2621.
```python
placements = module._placements
mesh = module._device_mesh
```
These may not exist if `args[0]` is not a DTensor?
```python
class VarlenAttentionWrapper(torch.nn.Module):

def _to_local(x: Any) -> Any:
```
Can this be defined inline? It doesn't need to be at root level.
I mimicked six's way in D83451817, but moving it inline sounds better.
```python
def __init__(self) -> None:
    super().__init__()
    self.register_forward_pre_hook(
        LocalMapModule._pre_hook, with_kwargs=True, prepend=True
```
Any reason why `prepend=True`?
@acisseJZhong Is there a way to use `__call__` with `torch.compile`? I am worried that hooks will cause ordering issues, which have happened many times before. We didn't have a good way to do this because we could not change `nn.Module`. But since we now have our own `Module`, I would prefer using `__call__` if possible. `torch.compile` is definitely the key blocker we need to fix.
cc @xmfan
@tianyu-l in fact this is the reason I also prefer to have local_map wrapping `super().__call__()`, and to solve any compile issues there. Otherwise, for this forward pre-hook and for the CP post-forward hook, I need `prepend=True` to make sure that the hooks I add wrap around the CP hooks.
I would prefer rolling back to my last commit and fixing the compile problem. 26cdceb @xmfan
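The `prepend=True` concern can be sketched in plain Python (illustrative names only, not the real `nn.Module` hook machinery): pre-hooks run front-to-back, so prepending puts our to-local conversion ahead of a CP hook that was registered earlier.

```python
# Pure-Python mock of pre-hook ordering with prepend semantics.
order = []

def cp_pre_hook(args):
    order.append("cp")
    return args

def to_local_pre_hook(args):
    order.append("to_local")
    return args

pre_hooks = [cp_pre_hook]  # CP registered its pre-hook first

# prepend=True semantics: insert our hook at the front of the list,
# so the DTensor -> local conversion runs before the CP hook.
pre_hooks.insert(0, to_local_pre_hook)

args = ()
for hook in pre_hooks:
    args = hook(args)

print(order)  # ['to_local', 'cp']
```

Without the prepend, `cp` would run first and see unconverted inputs, which is exactly the assertion failure discussed earlier in the thread.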
```python
return x
```
Suggested change:

```diff
-class LocalMapModule(Module):
+class LocalMapAttention(Module):
```
The implementation is very restricted, in the sense that

- input placements and output placements must be the same
- grad placements of inputs must stay the same. From a typing perspective, this is only true for `Shard`.

Taking Flex all-gather based CP as an example:

- on the TP mesh, things are Shard on the head dim. This is fine.
- on the CP mesh, q is Shard and kv are Replicate. The reason you can assume kv grads are also Replicate is that the CP API internally does this reduce-scatter for us (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_context_parallel/_cp_custom_ops.py#L37), which is fine for now.

Let's restrict the scope to Attention, since this is not a general local map module.
Please add a detailed docstring. (Ask me / Claude if anything is not clear.)
The scope of this `LocalMapModule` is restricted to Attention because right now we don't have a way to configure local_map, which is why this module is put here with everything fixed.
However, I think this is a good start to demonstrate how to do module-level local_map, and to extend it to a general implementation once our config includes local_map.
> on CP mesh, q is Shard and kv are Replicate

When I specify local_map / to_local / from_local, q, k, v should all be Shard on the TP mesh.
And later on the CP mesh, q, k, v are all Shard, and this is because of the CP pre-hook and post-hook? https://github.com/pytorch/pytorch/blob/4d01cdb5b2a633c45471bdaf8d8d544c4bb2572a/torch/distributed/tensor/experimental/_context_parallel/_attention.py#L1396
Oh, that's right. If that's the case, can we assert every placement to be Shard?
Yes, this is what I am doing in the local_map approach. Let me roll back and add an explicit assertion to Shard for now.
```python
self.register_forward_hook(LocalMapModule._post_hook)

@staticmethod
def _pre_hook(module, args, kwargs):
```
Give these more meaningful names, e.g. `_inputs_to_local`, `_outputs_from_local`.
```python
# TODO: cuDNN SDPA backward has a stride mismatch bug with CP.
# Exclude cuDNN until PyTorch fix lands. See https://github.com/pytorch/pytorch/issues/176915.
if attn_backend == "sdpa":
```
I don't know how I feel about this workaround.
If we have to do it, don't do it per model. Instead, do it in the SDPA module, e.g. maybe by overriding the pre-forward hook and subtracting cuDNN from `self.sdpa_backends`.
Force-pushed 9da4e39 to 844c107.
Force-pushed 844c107 to 2d06f2e.
This is a continuation of work in #2480 by @pianpwk
Summary
- `__call__` wraps `nn.Module.__call__` with local_map, converting TP DTensor inputs to local tensors before any `forward_pre_hook` fires, and wrapping outputs back to DTensor after all `forward_hook`s complete. Placements and device mesh are inferred from the input DTensors at runtime.
- `LocalMapModule` handles the TP/CP boundary. Qwen3 requires this because it uses `use_local_output=False` on wq/wk/wv (needed for QK norms with SequenceParallel), producing DTensors that CP hooks cannot directly consume.
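A toy sketch of the `__call__`-wrapping idea described in the summary, not the actual torchtitan implementation: `Wrapped` stands in for DTensor so the example runs on plain CPU tensors, and the pre-hook records what type the hooks actually observe.

```python
# Overriding __call__ so unwrapping happens outside all hooks: inputs are
# converted to local tensors before any forward_pre_hook fires, and the
# output is re-wrapped after all forward_hooks complete.
import torch
import torch.nn as nn

class Wrapped:  # illustrative stand-in for DTensor
    def __init__(self, t):
        self.t = t
    def to_local(self):
        return self.t

class LocalMapModule(nn.Module):
    def __call__(self, *args):
        local = tuple(a.to_local() if isinstance(a, Wrapped) else a for a in args)
        out = super().__call__(*local)  # hooks and forward() see plain tensors
        return Wrapped(out)             # re-wrap after all hooks have run

class Attn(LocalMapModule):
    def forward(self, q, k, v):
        return q + k + v

m = Attn()
seen = []
m.register_forward_pre_hook(lambda mod, args: seen.append(type(args[0]).__name__))

out = m(Wrapped(torch.ones(2)), Wrapped(torch.ones(2)), Wrapped(torch.ones(2)))
print(seen, isinstance(out, Wrapped))  # ['Tensor'] True
```

This is also the exact pattern the compile discussion above flags: a custom `__call__` calling `super().__call__` works in eager mode but currently trips up dynamo's module-call handling.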
Test