
[NPU] [Features] [Bugfix] Support mindiesd adaln#1537

Merged
gcanlin merged 14 commits into vllm-project:main from jiangmengyu18:support-mindiesd-adaln
Mar 5, 2026

Conversation

@jiangmengyu18 (Contributor) commented Feb 27, 2026


Purpose

  1. Accelerate AdaLN with mindiesd.
  2. Fix a bug in the AdaLN implementation based on torch_npu.

The following is a detailed explanation of the bug in the torch_npu-based AdaLN implementation.

First of all, torch_npu.npu_layer_norm_eval is an NPU-optimized implementation of torch.nn.functional.layer_norm. It is used here because it is exactly equivalent to AdaLN when batch_size = 1; however, it causes precision issues when batch_size > 1.

Since AdaLN normalizes the last dimension of x, which is d, normalized_shape must be [d], and the shapes of weight and bias must also be [d]. However, the shapes of scale_result and shift_result are [b, 1, d]. Although torch_npu.npu_layer_norm_eval does not raise an error when the shapes of weight and bias are [b, 1, d], it silently takes weight[0][0] and bias[0][0] and broadcasts them, instead of using the full weight and bias. This leads to precision issues when batch_size > 1.
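The broadcasting pitfall can be illustrated with a plain-NumPy reimplementation of layer-norm semantics. This is a minimal sketch, not the PR's code: the `layer_norm` helper and tensor names are illustrative, and the buggy path emulates what `npu_layer_norm_eval` reportedly does when handed `[b, 1, d]` weight/bias.

```python
import numpy as np

def layer_norm(x, weight, bias, eps=1e-6):
    # Normalize over the last dimension, then apply an affine transform.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

b, s, d = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((b, s, d))
scale = rng.standard_normal((b, 1, d))  # scale_result, shape [b, 1, d]
shift = rng.standard_normal((b, 1, d))  # shift_result, shape [b, 1, d]

# Correct AdaLN: normalize with unit weight / zero bias, then apply the
# per-batch scale/shift via ordinary broadcasting.
correct = layer_norm(x, np.ones(d), np.zeros(d)) * (1 + scale) + shift

# Emulated buggy behavior: only weight[0][0] and bias[0][0] are used for
# every batch element.
buggy = layer_norm(x, 1 + scale[0, 0], shift[0, 0])

print(np.allclose(correct[0], buggy[0]))  # True  (batch 0 matches)
print(np.allclose(correct[1], buggy[1]))  # False (batch 1 diverges)
```

This is why the two paths agree exactly at batch_size = 1 but drift apart as soon as batch_size > 1.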

For example (screenshots in the original PR):

  • Initialization: init
  • Error result: error_reuslt
  • Error result reproduced a second time: error_reuslt_1
  • Correct result: correct_result

Test Plan

  • Qwen-Image-Edit-2509

Test Result

The table below shows the time consumption of AdaLN when using native, torch_npu, and mindiesd, respectively.

| native | torch_npu | mindiesd |
| ------ | --------- | -------- |
| 36.8s  | 36.7s     | 35.8s    |

Performance (screenshots in the original PR):

  • native: output_native_adaln
  • torch_npu: output_torchnpu_adaln
  • mindiesd: output_mindiesd_adaln



jiangmengyu18 and others added 4 commits February 27, 2026 09:18
Signed-off-by: jiangmengyu18 <451528648@qq.com>
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8b3b5a0972


jiangmengyu18 and others added 4 commits February 27, 2026 15:31
Signed-off-by: jiangmengyu18 <451528648@qq.com>
@jiangmengyu18 jiangmengyu18 changed the title Support mindiesd adaln [NPU] [Features] [Bugfix] Support mindiesd adaln Feb 27, 2026
@gcanlin (Collaborator) left a comment:
Looks good now. Thanks!
cc @hsliuustc0106

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Feb 28, 2026
@lishunyang12 (Contributor) left a comment:
Bugfix looks right -- the old code passing scale/shift as weight/bias to npu_layer_norm_eval was definitely wrong for batch > 1. Left one comment on the exception handling.

Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
Signed-off-by: Hongsheng Liu <liuhongsheng4@huawei.com>
```python
import torch_npu

output = torch_npu.npu_layer_norm_eval(
    x, normalized_shape=[self.hidden_size], weight=(1 + scale_result), bias=shift_result, eps=self.eps
)
```
A reviewer (Contributor) commented: Duplicated line. Remove one.

@jiangmengyu18 (Author) replied: Do you mean the redundant logger.warning_once? I have updated the pull request to fix it.


@lishunyang12 (Contributor) left a comment:
Bugfix looks correct — applying scale/shift outside npu_layer_norm_eval avoids the silent broadcasting issue. Left a nit about a duplicated line.
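The fixed pattern the reviewer describes (scale/shift applied outside the layer-norm kernel) can be sketched in plain NumPy. The function name `adaln_fixed` is hypothetical; the real code runs on torch/torch_npu, but the math is the same.

```python
import numpy as np

def adaln_fixed(x, scale_result, shift_result, eps=1e-6):
    """Sketch of the fix: normalize the last dimension with no affine
    parameters, then apply the per-batch [b, 1, d] scale/shift by ordinary
    broadcasting, instead of passing them to the kernel as weight/bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / np.sqrt(var + eps)  # plain layer norm over d
    return normed * (1 + scale_result) + shift_result

# Every batch element now gets its own scale/shift.
b, s, d = 2, 3, 4
rng = np.random.default_rng(1)
x = rng.standard_normal((b, s, d))
scale = rng.standard_normal((b, 1, d))
shift = rng.standard_normal((b, 1, d))
out = adaln_fixed(x, scale, shift)

# Batch 1 agrees with normalizing it alone with its own scale/shift.
ref1 = adaln_fixed(x[1:], scale[1:], shift[1:])
print(np.allclose(out[1], ref1[0]))  # True
```

Because the broadcast happens in plain elementwise ops rather than inside the kernel's weight/bias arguments, there is no silent fallback to `weight[0][0]`/`bias[0][0]`.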

Signed-off-by: jiangmengyu18 <451528648@qq.com>
@gcanlin gcanlin merged commit 5493e75 into vllm-project:main Mar 5, 2026
7 checks passed
ahengljh pushed a commit to ahengljh/vllm-omni that referenced this pull request Mar 5, 2026
Signed-off-by: jiangmengyu18 <451528648@qq.com>
Signed-off-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
linyueqian pushed a commit to lishunyang12/vllm-omni that referenced this pull request Mar 5, 2026
Signed-off-by: jiangmengyu18 <451528648@qq.com>
Signed-off-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
