
Adopt Local Map Wrapper for Inner Attention #2557

Open

acisseJZhong wants to merge 15 commits into main from cp_localmap

Conversation

Contributor

@acisseJZhong acisseJZhong commented Mar 11, 2026

This is a continuation of work in #2480 by @pianpwk

Summary

  • Add _InnerAttentionBase base class to attention.py — overrides __call__
    to wrap nn.Module.__call__ with local_map, converting TP DTensor inputs
    to local tensors before any forward_pre_hook fires, and wrapping
    outputs back to DTensor after all forward_hooks complete. Placements
    and device mesh are inferred from the input DTensors at runtime.
  • Enable TP+CP support for Qwen3 by having LocalMapModule handle the
    TP/CP boundary. Qwen3 requires this because it uses use_local_output=False
    on wq/wk/wv (needed for QK norms with SequenceParallel), producing
    DTensors that CP hooks cannot directly consume.
  • Add integration test for Qwen3 FSDP+TP+CP on 8 GPUs.

Test

NCCL_NVLS_ENABLE=0 MODULE=qwen3 CONFIG=qwen3_debugmodel ./run_train.sh --parallelism.data_parallel_shard_degree 2 --parallelism.tensor_parallel_degree 2 --parallelism.context_parallel_degree 2 --compile.enable  --compile.backend eager 

@meta-cla meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Mar 11, 2026
@acisseJZhong acisseJZhong requested a review from pianpwk March 11, 2026 23:59
v, DTensor
), "q, k, v should all be DTensors"
if self._local_map_fn is None:
self._local_map_fn = local_map(
Contributor Author

@acisseJZhong acisseJZhong Mar 12, 2026

@fegin I had this local_map wrapping around the CP pre-forward hook and post-forward hook, so that the local map handles conversion of DTensor to plain tensor before the CP pre-forward hook fires. If it's the other way around, I hit this assertion error: https://github.com/pytorch/pytorch/blame/main/torch/distributed/tensor/experimental/_context_parallel/_attention.py#L1394

Not sure if this is the best way; maybe you already have a plan to rewrite the CP API for full DTensor.

Contributor

Yes, this makes sense. Some comments:

  1. Instead of _InnerAttentionLocalMap, we should just implement this logic in Module. This is something we need anyway for MoE. The implementation can be less generic at this point, with some TODOs noting that once we integrate the config system with the sharding spec, it should become more generic.

  2. We will have to ensure torch.compile doesn't break. So try with --compile.enable.

Contributor

I realized that these inner wrappers are not Modules yet. You can just wrap them in a Module, with the Config left empty for now.

Contributor

Or can we inherit Module and call them LocalMapModule?

Contributor Author

@acisseJZhong acisseJZhong Mar 12, 2026

Compile is not working with this super().__call__() approach. Claude tells me:

Why it breaks compile: when dynamo encounters self.inner_attention(xq, xk, xv, ...), it detects that
LocalMapModule has a custom __call__ (not the standard nn.Module.__call__). Dynamo then tries to trace through
the custom __call__ as a regular Python function. Inside __call__, the super().__call__(q, k, v, **kwargs) call
bypasses dynamo's special nn.Module call handling (which normally inlines forward() and properly handles hooks).
This causes FX nodes to be created without proper meta["val"] metadata, leading to the inductor KeyError: 'val'
error.

Not sure how six handles compile with this __call__ override.

Contributor

I think to_local and from_local on the module boundary is fine, because we can still force a pair of them within a clear boundary. Would be good to know if it works with spmd_types, which is a context manager.

Member

@xmfan xmfan Mar 13, 2026

for context managers, we can manually call __enter__() and __exit__()
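A minimal sketch of that pattern, using a hypothetical `spmd_mode` context manager as a stand-in (the real spmd_types API is not shown here):

```python
from contextlib import contextmanager

@contextmanager
def spmd_mode(log):
    # Hypothetical stand-in for the spmd_types context manager discussed above.
    log.append("enter")
    try:
        yield
    finally:
        log.append("exit")

# Manual protocol calls, instead of a `with` block:
log = []
cm = spmd_mode(log)
cm.__enter__()
try:
    log.append("body")
finally:
    cm.__exit__(None, None, None)
```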

Contributor

@xmfan there's some limitation due to __torch_function__, so we can't use __enter__ and __exit__. This is the current entry point https://fburl.com/code/ypffryqp

Contributor Author

@xmfan I changed to a hook-based approach here 26cdceb but compile still fails. P2234099693

Contributor

Directly replacing forward with the wrapped forward works: #2621.

Comment on lines +120 to +121
placements = module._placements
mesh = module._device_mesh
Contributor

may not exist if args[0] is not DTensor?



class VarlenAttentionWrapper(torch.nn.Module):
def _to_local(x: Any) -> Any:
Contributor

can define inline? doesn't need to be at root level.

Contributor Author

I mimicked six's way (D83451817), but moving it inline sounds better.

def __init__(self) -> None:
super().__init__()
self.register_forward_pre_hook(
LocalMapModule._pre_hook, with_kwargs=True, prepend=True
Contributor

any reason why prepend=True?

Contributor

@acisseJZhong Is there a way to use __call__ with torch.compile? I am worried that hooks will cause ordering issues, which has happened many times before. We didn't have a good way to do this because we could not change nn.Module. But since we now have our own Module, I would prefer using __call__ if possible. torch.compile is definitely the key blocker we need to fix.

cc @xmfan

Contributor Author

@acisseJZhong acisseJZhong Mar 13, 2026

@tianyu-l in fact this is the reason I also prefer to have local_map wrap super().__call__, and solve any compile issues there. Otherwise, for this forward pre-hook and for the CP post-forward hook, I need prepend=True to make sure the hooks I added wrap around the CP hooks.

I would prefer rolling back to my last commit and fixing the compile problem: 26cdceb @xmfan
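The hook-ordering point can be seen in a small sketch: with `prepend=True`, a later-registered pre-hook runs first, which is how a to-local hook can fire before a previously registered CP hook (the hook names below are illustrative):

```python
import torch
import torch.nn as nn

order = []

class Attn(nn.Module):
    def forward(self, x):
        order.append("forward")
        return x

m = Attn()
# Registered first, e.g. by context-parallel setup.
m.register_forward_pre_hook(lambda mod, args: order.append("cp_pre"))
# prepend=True pushes this hook to the front of the hook list, so it runs
# before the CP hook even though it was registered later.
m.register_forward_pre_hook(
    lambda mod, args: order.append("to_local_pre"), prepend=True
)
m(torch.zeros(1))
```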

return x


class LocalMapModule(Module):
Contributor

Suggested change
class LocalMapModule(Module):
class LocalMapAttention(Module):

The implementation is very restricted, in the sense that

  • input placements and output placements must be the same
  • grad placements of inputs must stay the same. From a typing perspective, this is only true for Shard.

Taking Flex all-gather based CP as an example

Let's restrict the scope to Attention, since this is not a general local map module.

Please add a detailed docstring. (Ask me / Claude if anything is not clear.)

Contributor

The scope of this LocalMapModule is restricted to Attention because right now we don't have a way to configure local map, which is why this module is put here and everything is fixed.

That said, I think this is a good start to demonstrate how to do module-level local map, and we can extend it to a general implementation once our config includes local map.

Contributor Author

on CP mesh, q is Shard and kv are Replicate

When I specify local_map / to_local / from_local, q, k, v should all be Shard on the TP mesh.
And later, on the CP mesh, q, k, v should all be Shard, and this is because of the CP pre-hook and post-hook? https://github.com/pytorch/pytorch/blob/4d01cdb5b2a633c45471bdaf8d8d544c4bb2572a/torch/distributed/tensor/experimental/_context_parallel/_attention.py#L1396

Contributor

oh, that's right -- if that's the case, can we assert every placement to be Shard?

Contributor Author

Yes, this is what I am doing in the local map approach. Let me roll back and add an explicit assertion to Shard for now.

self.register_forward_hook(LocalMapModule._post_hook)

@staticmethod
def _pre_hook(module, args, kwargs):
Contributor

give them more meaningful names, e.g. _inputs_to_local, _outputs_from_local


# TODO: cuDNN SDPA backward has a stride mismatch bug with CP.
# Exclude cuDNN until PyTorch fix lands. See https://github.com/pytorch/pytorch/issues/176915.
if attn_backend == "sdpa":
Contributor

I don't know how I feel about this workaround.

If we have to do it, don't do it per model. Instead, do it in the SDPA module, e.g. by overriding the pre-forward hook and subtracting cudnn from self.sdpa_backends.

@acisseJZhong acisseJZhong added this to the New Feature, Model, Misc milestone Mar 13, 2026
@acisseJZhong acisseJZhong self-assigned this Mar 13, 2026
@acisseJZhong acisseJZhong moved this from Todo to In Progress in 26H1 TorchTitan Development Mar 13, 2026
@wwwjn wwwjn linked an issue Mar 13, 2026 that may be closed by this pull request
@acisseJZhong acisseJZhong force-pushed the cp_localmap branch 3 times, most recently from 9da4e39 to 844c107 Compare March 13, 2026 22:02

Labels

ciflow/8gpu, CLA Signed (managed by the Meta Open Source bot)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

qwen3 TP + CP bug

5 participants