feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexer (LI) patches #221
Open: 0hujun wants to merge 16 commits into modelscope:main from 0hujun:main
Commits:
98e69cd fix: Npu Group MatMul op patchs only in EP (0hujun)
1992ca0 Update src/twinkle/kernel/monkey_patch_npu.py (0hujun)
598c5ab Update src/twinkle/kernel/monkey_patch_npu.py (0hujun)
f7dafe5 Merge branch 'modelscope:main' into main (0hujun)
c6590ce feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
7d05df5 feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
1e770a4 feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
451ef16 feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
0a6447b feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
884a78c feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
9b1e26c feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
aa3caea feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
491d562 feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
947a94e feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
74b2cdb feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
57adc45 feat: add DeepSeek-V4 NPU Sparse Attention (SAS) and Lightning Indexe… (0hujun)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # DeepSeek-V4 NPU Sparse Attention (SAS) / Lightning Indexer (LI) | ||
|
|
||
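| Twinkle provides NPU acceleration patches for DeepSeek-V4. They monkey-patch the attention computation and the indexer implementation in transformers, so the transformers source code does not need to be modified. | ||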
|
|
||
| ## Features | ||
|
|
||
| ### SAS (Sparse Attention Shared-KV) | ||
|
|
||
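| Replaces the standard attention computation in `DeepseekV4Attention.forward` with `SparseAttnSharedKV`, the fused sparse attention kernel provided by mindspeed. Three attention layer types are supported: | ||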
|
|
||
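| - **Sliding Attention**: pure sliding-window attention | ||
| - **CSA (Compressed Sparse Attention)**: compressed sparse attention; the Lightning Indexer selects the top-k compressed entries | ||
| - **HCA (Heavily Compressed Attention)**: heavily compressed attention; all compressed entries are visible | ||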
|
|
||
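As a rough illustration of how the three layer types differ in what a query can see, the sketch below builds index sets in plain Python. The function names, the `window` and `top_k` parameters, and the toy scores are assumptions made for illustration; how the real kernel handles causality for compressed entries and how it lays out its inputs is not taken from the patch.

```python
def sliding_window_keys(query_pos, window):
    """Token positions visible to query_pos under a causal sliding window."""
    start = max(0, query_pos - window + 1)
    return list(range(start, query_pos + 1))


def csa_entries(scores, top_k):
    """Indices of the top_k highest-scoring compressed entries, in index order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:top_k])


def hca_entries(num_compressed):
    """HCA-style visibility: every compressed entry is selected."""
    return list(range(num_compressed))


window_keys = sliding_window_keys(query_pos=5, window=3)
selected = csa_entries(scores=[0.1, 0.9, 0.3, 0.7], top_k=2)
all_entries = hca_entries(num_compressed=4)
print(window_keys, selected, all_entries)
```

In this toy model the only difference between CSA and HCA is whether a scoring step narrows the set of compressed entries before attention runs.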
| ### LI (Lightning Indexer) | ||
|
|
||
| Replaces the torch implementation in `DeepseekV4Indexer.forward` with `npu_lightning_indexer` from mindspeed to accelerate top-k index selection. | ||
|
|
||
| **Note**: In the current version, SAS and LI cannot be enabled at the same time. | ||
|
|
||
| ## Dependencies | ||
|
|
||
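| - **[ops-transformer](https://gitcode.com/cann/ops-transformer)**: provides the NPU operator implementations; must be compiled and installed | ||
| - **[mindspeed](https://gitcode.com/Ascend/MindSpeed)**: provides the NPU operator call wrappers; clone it with `git clone`, switch to the `master` branch, and install it manually | ||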
| - `mindspeed.ops.npu_sparse_attn_shared_kv.SparseAttnSharedKV` (SAS) | ||
| - `mindspeed.ops.npu_lightning_indexer` (LI) | ||
| - **transformers**: must include DeepSeek-V4 model support | ||
| - **torch_npu**: the Ascend NPU runtime | ||
|
|
||
| ## Environment variables | ||
|
|
||
| | Variable | Default | Description | | ||
| |------|--------|------| | ||
| | `TWINKLE_NPU_DSV4_SAS` | `0` | Set to `1` to enable the SAS patch | | ||
| | `TWINKLE_NPU_DSV4_LI` | `0` | Set to `1` to enable the LI patch | | ||
|
|
||
| **Constraint**: `TWINKLE_NPU_DSV4_SAS` and `TWINKLE_NPU_DSV4_LI` must not both be set to `1`. | ||
|
|
||
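The constraint above can be checked before training starts. The following is a minimal sketch of such a check, not the logic in `monkey_patch_npu.py`; the helper name `resolve_dsv4_patches` and the choice to raise `ValueError` are assumptions.

```python
import os


def resolve_dsv4_patches(env=None):
    """Return (sas_enabled, li_enabled); reject the unsupported combination."""
    env = os.environ if env is None else env
    sas = env.get('TWINKLE_NPU_DSV4_SAS', '0') == '1'
    li = env.get('TWINKLE_NPU_DSV4_LI', '0') == '1'
    if sas and li:
        # Assumption: fail fast instead of silently preferring one patch.
        raise ValueError('TWINKLE_NPU_DSV4_SAS and TWINKLE_NPU_DSV4_LI cannot both be 1')
    return sas, li


print(resolve_dsv4_patches({'TWINKLE_NPU_DSV4_SAS': '1'}))
```

Passing a plain dict instead of `os.environ` keeps the helper easy to test without touching the process environment.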
| ## Usage examples | ||
|
|
||
| ### Docker image (optional) | ||
| ```shell | ||
| #A3 | ||
| docker pull swr.cn-southwest-2.myhuaweicloud.com/ascend-sact/twinkle-npu:v4 | ||
| ``` | ||
|
|
||
| ### Enable SAS | ||
|
|
||
| ```bash | ||
| export TWINKLE_NPU_DSV4_SAS=1 | ||
| torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py | ||
| ``` | ||
|
|
||
| ### Enable LI | ||
|
|
||
| ```bash | ||
| export TWINKLE_NPU_DSV4_LI=1 | ||
|
|
||
| torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py | ||
| ``` | ||
|
|
||
| ### Full example script (ds16_sas.sh) | ||
|
|
||
| ```bash | ||
| #!/bin/bash | ||
| export GLOO_SOCKET_IFNAME="enp162s0f0" | ||
| export HCCL_SOCKET_IFNAME="enp162s0f0" | ||
| export HCCL_CONNECT_TIMEOUT=7200 | ||
| export HCCL_EXEC_TIMEOUT=7200 | ||
| export ACL_DEVICE_SYNC_TIMEOUT=7200 | ||
| export HCCL_IF_BASE_PORT=30000 | ||
| export BATCH_SIZE=8 | ||
| export MAX_STEPS=10 | ||
| export GRADIENT_CHECKPOINTING=1 | ||
| export USE_EP=1 | ||
|
|
||
| # Enable the twinkle SAS patch | ||
| export TWINKLE_NPU_DSV4_SAS=1 | ||
|
|
||
| export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 | ||
| source /usr/local/Ascend/cann/opp/vendors/custom_transformer/bin/set_env.bash | ||
| torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py | ||
| ``` | ||
|
|
||
| ## How it works | ||
|
|
||
| The patches are applied automatically during `apply_npu_patch()` (after EP sharding and before the FSDP wrap). They replace the original implementations as follows: | ||
|
|
||
| 1. **Compressor patch**: wraps the `forward` methods of `DeepseekV4HCACompressor` and `DeepseekV4CSACompressor` so that they return the 3-tuple `(compressed_kv, block_bias, top_k_indices)` | ||
| 2. **Attention patch**: replaces `DeepseekV4Attention.forward` and calls `SparseAttnSharedKV.apply()` instead of the standard attention dispatch | ||
| 3. **Indexer patch**: replaces `DeepseekV4Indexer.forward` and calls `mindspeed.ops.npu_lightning_indexer` instead of the torch implementation | ||
|
|
||
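The compressor step in the list above amounts to wrapping a method so its return value always has three slots. The sketch below shows that pattern on a toy class. `ToyCompressor`, `ensure_three_tuple`, and padding with `None` are assumptions; how the real patch fills missing elements is not shown in this README.

```python
import functools


def ensure_three_tuple(forward):
    """Wrap forward so it always returns a 3-tuple, padding with None."""
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        out = forward(*args, **kwargs)
        if not isinstance(out, tuple):
            out = (out,)
        if len(out) > 3:
            raise ValueError(f'expected at most 3 outputs, got {len(out)}')
        return out + (None,) * (3 - len(out))
    return wrapper


class ToyCompressor:
    """Stand-in for a compressor whose forward returns only two values."""

    def forward(self, x):
        return x * 2, 0.0


# Monkey-patch at class level so every instance picks up the wrapper.
ToyCompressor.forward = ensure_three_tuple(ToyCompressor.forward)
print(ToyCompressor().forward(3))
```

Patching the class attribute rather than individual instances means modules created before or after the patch behave the same way.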
| Every patch includes an `ImportError` fallback: when mindspeed is unavailable, it falls back to the original implementation automatically. | ||
|
|
||
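The fallback behaviour can be pictured as a guarded import around the method swap. This is a sketch of the general pattern only: `ToyIndexer`, `patch_indexer`, the `fast_forward` attribute, and the module name are hypothetical and chosen so the import fails on purpose.

```python
import importlib


class ToyIndexer:
    """Stand-in for an indexer with a pure-Python top-k forward."""

    def forward(self, scores, k):
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return order[:k]


def patch_indexer(cls, module_name):
    """Swap cls.forward for module_name.fast_forward; return True if patched."""
    try:
        fast = importlib.import_module(module_name)
    except ImportError:
        # Accelerated module missing: keep the original forward untouched.
        return False
    cls.forward = fast.fast_forward
    return True


# Hypothetical module name that is not installed, so the fallback path runs.
patched = patch_indexer(ToyIndexer, 'hypothetical_npu_kernels_not_installed')
print(patched, ToyIndexer().forward([0.2, 0.9, 0.5], 2))
```

Returning a boolean lets the caller log whether the accelerated path is active, which is what the verification log lines are for.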
| ## Verification | ||
|
|
||
| After running a test, check whether the log contains: | ||
|
|
||
| ``` | ||
| [NPU] [DSV4-SAS] Twinkle sparse attention active (layer_type=..., cmp_ratio=..., topk=...) | ||
| ``` | ||
|
|
||
| or | ||
|
|
||
| ``` | ||
| [NPU] [DSV4-LI] Twinkle lightning indexer active (sparse_count=..., cmp_ratio=...) | ||
| ``` | ||
|
|
||
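The check can be scripted. The sketch below searches captured log text for the two marker prefixes quoted above; the helper name `active_patches` and the idea of passing the log as a string are assumptions, and the sample line reuses the README's placeholder values rather than real output.

```python
import re

# Marker prefixes copied from the log lines documented above.
MARKERS = {
    'SAS': re.compile(r'\[NPU\] \[DSV4-SAS\] Twinkle sparse attention active'),
    'LI': re.compile(r'\[NPU\] \[DSV4-LI\] Twinkle lightning indexer active'),
}


def active_patches(log_text):
    """Return the sorted names of patches whose marker appears in log_text."""
    return sorted(name for name, pattern in MARKERS.items() if pattern.search(log_text))


sample = '[NPU] [DSV4-SAS] Twinkle sparse attention active (layer_type=..., cmp_ratio=..., topk=...)'
print(active_patches(sample))
```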
| ## Related files | ||
|
|
||
| - `src/twinkle/kernel/deepseek_v4_npu.py`: core patch implementation | ||
| - `src/twinkle/kernel/monkey_patch_npu.py`: patch registration and environment-variable control |
123 changes: 123 additions & 0 deletions
cookbook/transformers/deepseek_v4_patch/ep_fsdp2_lora_deepseek_v4_npu.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| import os | ||
| import twinkle | ||
| from peft import LoraConfig | ||
| from transformers import AutoConfig | ||
| from twinkle import DeviceMesh, Platform, get_device_placement, get_logger | ||
| from twinkle.dataloader import DataLoader | ||
| from twinkle.dataset import Dataset, DatasetMeta | ||
| from twinkle.kernel import apply_npu_patch | ||
| from twinkle.model import TransformersModel | ||
| from twinkle.preprocessor import SelfCognitionProcessor | ||
|
|
||
|
|
||
| logger = get_logger() | ||
| MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') | ||
| DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') | ||
| TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') | ||
| OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') | ||
|
|
||
| MAX_LENGTH = int(os.environ.get('MAX_LENGTH', '4096')) | ||
| BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '32')) | ||
| GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2')) | ||
| LOG_INTERVAL = GRAD_ACCUM_STEPS | ||
| LR = float(os.environ.get('LR', '1e-5')) | ||
| MAX_STEPS = int(os.environ.get('MAX_STEPS', '0')) | ||
| SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50')) | ||
| USE_LORA = os.environ.get('USE_LORA', '1') == '1' | ||
| MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) | ||
| IGNORE_MISMATCHED_SIZES = os.environ.get('IGNORE_MISMATCHED_SIZES', '1') == '1' | ||
| GRADIENT_CHECKPOINTING = os.environ.get('GRADIENT_CHECKPOINTING', '1') == '1' | ||
| RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1' | ||
| LORA_TARGET_MODULES = os.environ.get( | ||
|     'LORA_TARGET_MODULES', | ||
|     'wq_a,wq_b,wkv,wgate,gate_proj,up_proj,down_proj', | ||
| ) | ||
| USE_EP = os.environ.get('USE_EP', '0') == '1' | ||
| ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') | ||
| EP_SIZE = BATCH_SIZE if USE_EP else 1 | ||
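| # The mesh sizes are tied to BATCH_SIZE; ds16_sas.sh sets BATCH_SIZE=8 to match --nproc_per_node=8. | ||
| device_mesh = DeviceMesh.from_sizes( | ||
|     fsdp_size=BATCH_SIZE, | ||
|     dp_size=1, | ||
|     ep_size=EP_SIZE, | ||
|     device_type=Platform.get_platform().device_prefix(), | ||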
| ) | ||
|
|
||
| twinkle.initialize(mode='local', global_device_mesh=device_mesh) | ||
|
|
||
|
|
||
| def create_dataset(data_slice=None): | ||
|     dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) | ||
|     dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) | ||
|     dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) | ||
|     return dataset | ||
|
|
||
| def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): | ||
|     return model.save( | ||
|         name=checkpoint_name, | ||
|         output_dir=OUTPUT_DIR, | ||
|         adapter_name=ADAPTER_NAME, | ||
|         save_optimizer=True, | ||
|         consumed_train_samples=dataloader.get_state()['consumed_train_samples'], | ||
|     ) | ||
|
|
||
| def train(): | ||
|     dataset = create_dataset() | ||
|     dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) | ||
|
|
||
|     config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
|     if hasattr(config, 'use_cache'): | ||
|         config.use_cache = False | ||
|     model = TransformersModel( | ||
|         model_id=MODEL_ID, | ||
|         config=config, | ||
|         device_mesh=device_mesh, | ||
|         strategy="native_fsdp", | ||
|         memory_efficient_init=True, | ||
|         ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, | ||
|         fsdp_config={ | ||
|             'reshard_after_forward': RESHARD_AFTER_FORWARD, | ||
|             'expert_parallel': { | ||
|                 'enabled': USE_EP, | ||
|                 'router_dtype': 'fp32', | ||
|                 'keep_router_logits': False, | ||
|             } | ||
|         }, | ||
|     ) | ||
|
|
||
|     apply_npu_patch(model) | ||
|
|
||
|     if USE_LORA: | ||
|         lora_target_modules = [name.strip() for name in LORA_TARGET_MODULES.split(',') if name.strip()] | ||
|         lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=lora_target_modules) | ||
|         # Use ADAPTER_NAME so the adapter matches the optimizer, scheduler, and checkpoint calls below. | ||
|         model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS) | ||
|
|
||
|     if not GRADIENT_CHECKPOINTING: | ||
|         model.model.gradient_checkpointing_disable() | ||
|
|
||
|     model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME) | ||
|     model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME) | ||
|     model.set_lr_scheduler( | ||
|         scheduler_cls='CosineWarmupScheduler', | ||
|         num_warmup_steps=1, | ||
|         num_training_steps=len(dataloader), | ||
|         adapter_name=ADAPTER_NAME, | ||
|     ) | ||
|     optimizer_group = model.optimizer_group[ADAPTER_NAME] | ||
|     # MAX_STEPS and SAVE_STEPS are read from the environment above but are not used in this loop. | ||
|     for batch in dataloader: | ||
|         if callable(batch): | ||
|             batch = batch() | ||
|         model.forward_backward(inputs=batch) | ||
|         model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) | ||
|         cur_step = optimizer_group.cur_step | ||
|         if cur_step > 0 and cur_step % LOG_INTERVAL == 0: | ||
|             metric = model.calculate_metric(is_training=True) | ||
|             if callable(metric): | ||
|                 metric = metric() | ||
|             logger.info(f'Current step {cur_step} of {len(dataloader)}, metric: {metric}') | ||
|
|
||
|     final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) | ||
|
|
||
|     logger.info(f'Saved final adapter to {final_checkpoint}') | ||
|
|
||
| if __name__ == '__main__': | ||
|     train() | ||
|
|
||