Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions cookbook/transformers/deepseek_v4_patch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# DeepSeek-V4 NPU Sparse Attention (SAS) / Lightning Indexer (LI)

Twinkle 提供的 DeepSeek-V4 NPU 加速 patch,通过 monkey-patch 方式替换 transformers 中的注意力计算和索引器实现,无需修改 transformers 源码。

## 功能说明

### SAS (Sparse Attention Shared-KV)

替换 `DeepseekV4Attention.forward` 中的标准注意力计算,使用 mindspeed 提供的融合稀疏注意力核 `SparseAttnSharedKV`,支持三种注意力层类型:

- **Sliding Attention**: 纯滑动窗口注意力
- **CSA (Compressed Sparse Attention)**: 压缩稀疏注意力,使用 Lightning Indexer 选择 top-k 压缩条目
- **HCA (Heavily Compressed Attention)**: 高度压缩注意力,所有压缩条目可见

### LI (Lightning Indexer)

替换 `DeepseekV4Indexer.forward` 中的 torch 实现,使用 mindspeed 提供的 `npu_lightning_indexer` 加速 top-k 索引选择。

**注意**: 当前版本 SAS 和 LI 不能同时启用。

## 依赖

- **[ops-transformer](https://gitcode.com/cann/ops-transformer)**: 提供 NPU 算子实现,需要编译安装
- **[mindspeed](https://gitcode.com/Ascend/MindSpeed)**: 提供 NPU 算子调用实现,需要使用git clone下载mindspeed并切换到master分支进行手动安装
- `mindspeed.ops.npu_sparse_attn_shared_kv.SparseAttnSharedKV` (SAS)
- `mindspeed.ops.npu_lightning_indexer` (LI)
- **transformers**: 需包含 DeepSeek-V4 模型支持
- **torch_npu**: Ascend NPU 运行时

## 环境变量

| 变量 | 默认值 | 说明 |
|------|--------|------|
| `TWINKLE_NPU_DSV4_SAS` | `0` | 启用 SAS patch |
| `TWINKLE_NPU_DSV4_LI` | `0` | 启用 LI patch |

**约束**: `TWINKLE_NPU_DSV4_SAS` 和 `TWINKLE_NPU_DSV4_LI` 不能同时设置为 `1`。

## 使用示例

### 镜像(可选)
```shell
#A3
docker pull swr.cn-southwest-2.myhuaweicloud.com/ascend-sact/twinkle-npu:v4
```

### 启用 SAS

```bash
export TWINKLE_NPU_DSV4_SAS=1
torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py
```

### 启用 LI

```bash
export TWINKLE_NPU_DSV4_LI=1

torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py
```

### 完整示例脚本 (ds16_sas.sh)

```bash
#!/bin/bash
export GLOO_SOCKET_IFNAME="enp162s0f0"
export HCCL_SOCKET_IFNAME="enp162s0f0"
export HCCL_CONNECT_TIMEOUT=7200
export HCCL_EXEC_TIMEOUT=7200
export ACL_DEVICE_SYNC_TIMEOUT=7200
export HCCL_IF_BASE_PORT=30000
export BATCH_SIZE=8
export MAX_STEPS=10
export GRADIENT_CHECKPOINTING=1
export USE_EP=1

# 启用 twinkle SAS patch
export TWINKLE_NPU_DSV4_SAS=1

export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
source /usr/local/Ascend/cann/opp/vendors/custom_transformer/bin/set_env.bash
torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py
```

## 实现原理

Patch 在 `apply_npu_patch()` 阶段自动应用(位于 EP sharding 之后、FSDP wrap 之前),通过以下方式替换原始实现:

1. **Compressor patch**: 包装 `DeepseekV4HCACompressor` 和 `DeepseekV4CSACompressor` 的 `forward` 方法,确保返回 3-tuple `(compressed_kv, block_bias, top_k_indices)`
2. **Attention patch**: 替换 `DeepseekV4Attention.forward`,调用 `SparseAttnSharedKV.apply()` 替代标准注意力 dispatch
3. **Indexer patch**: 替换 `DeepseekV4Indexer.forward`,调用 `mindspeed.ops.npu_lightning_indexer` 替代 torch 实现

所有 patch 均包含 `ImportError` fallback,当 mindspeed 不可用时自动回退到原始实现。

## 验证

运行测试后,检查日志中是否出现:

```
[NPU] [DSV4-SAS] Twinkle sparse attention active (layer_type=..., cmp_ratio=..., topk=...)
```


```
[NPU] [DSV4-LI] Twinkle lightning indexer active (sparse_count=..., cmp_ratio=...)
```

## 相关文件

- `src/twinkle/kernel/deepseek_v4_npu.py`: Patch 核心实现
- `src/twinkle/kernel/monkey_patch_npu.py`: Patch 注册和环境变量控制
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import os
import twinkle
from peft import LoraConfig
from transformers import AutoConfig
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.kernel import apply_npu_patch
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor


logger = get_logger()
MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')

MAX_LENGTH = int(os.environ.get('MAX_LENGTH', '4096'))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '32'))
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2'))
LOG_INTERVAL = GRAD_ACCUM_STEPS
LR = float(os.environ.get('LR', '1e-5'))
MAX_STEPS = int(os.environ.get('MAX_STEPS', '0'))
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50'))
USE_LORA = os.environ.get('USE_LORA', '1') == '1'
MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
IGNORE_MISMATCHED_SIZES = os.environ.get('IGNORE_MISMATCHED_SIZES', '1') == '1'
GRADIENT_CHECKPOINTING = os.environ.get('GRADIENT_CHECKPOINTING', '1') == '1'
RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1'
LORA_TARGET_MODULES = os.environ.get(
'LORA_TARGET_MODULES',
'wq_a,wq_b,wkv,wgate,gate_proj,up_proj,down_proj',
)
USE_EP = os.environ.get('USE_EP', '0') == '1'
ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default')
EP_SIZE = BATCH_SIZE if USE_EP else 1
device_mesh = DeviceMesh.from_sizes(
fsdp_size=BATCH_SIZE,
dp_size=1,
ep_size=EP_SIZE,
device_type=Platform.get_platform().device_prefix(),
)

twinkle.initialize(mode='local', global_device_mesh=device_mesh)


def create_dataset(data_slice=None):
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID))
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
return dataset

def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
return model.save(
name=checkpoint_name,
output_dir=OUTPUT_DIR,
adapter_name=ADAPTER_NAME,
save_optimizer=True,
consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
)

def train():
dataset = create_dataset()
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
Comment thread
0hujun marked this conversation as resolved.

config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
if hasattr(config, 'use_cache'):
config.use_cache = False
model = TransformersModel(
model_id=MODEL_ID,
config=config,
device_mesh=device_mesh,
strategy="native_fsdp",
memory_efficient_init=True,
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': RESHARD_AFTER_FORWARD,
'expert_parallel': {
'enabled': USE_EP,
'router_dtype': 'fp32',
'keep_router_logits': False,
}
},
)

apply_npu_patch(model)

if USE_LORA:
lora_target_modules = [name.strip() for name in LORA_TARGET_MODULES.split(',') if name.strip()]
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=lora_target_modules)
model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS)

if not GRADIENT_CHECKPOINTING:
model.model.gradient_checkpointing_disable()

model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME)
model.set_lr_scheduler(
scheduler_cls='CosineWarmupScheduler',
num_warmup_steps=1,
num_training_steps=len(dataloader),
adapter_name=ADAPTER_NAME,
)
optimizer_group = model.optimizer_group[ADAPTER_NAME]
for batch in dataloader:
if callable(batch):
batch = batch()
model.forward_backward(inputs=batch)
model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
cur_step = optimizer_group.cur_step
if cur_step > 0 and cur_step % LOG_INTERVAL == 0:
metric = model.calculate_metric(is_training=True)
if callable(metric):
metric = metric()
logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}')

final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader)
Comment thread
0hujun marked this conversation as resolved.
logger.info(f'Saved final adapter to {final_checkpoint}')

if __name__ == '__main__':
train()

Loading
Loading