
[DS 3.2] add ReshapeAndCacheByGroup Ascend ops #7382

Open
ZT-AIA wants to merge 9 commits into vllm-project:main from ZT-AIA:reshape_ops

Conversation


@ZT-AIA ZT-AIA commented Mar 17, 2026

What this PR does / why we need it?

This PR optimizes the reshape_and_cache and scatter_nd_update operators based on the hardware features of Ascend.

Does this PR introduce any user-facing change?

No

How was this patch tested?

zengtian (A) and others added 2 commits March 17, 2026 17:12
Signed-off-by: zengtian (A) <z00893411@china.huawei.com>
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a specialized custom operator, ReshapeAndCacheByGroup, designed to enhance the performance of key-value cache operations specifically on Ascend hardware. By providing a dedicated, optimized implementation for these critical memory management tasks, the change aims to improve the overall efficiency and speed of models running on Ascend platforms. The integration ensures that PyTorch-based workflows can seamlessly utilize this hardware-accelerated functionality, streamlining the process of updating and accessing KV caches.

Highlights

  • New Custom Operator: Introduced a new custom Ascend operator named ReshapeAndCacheByGroup to optimize KV cache management on Ascend hardware.
  • Ascend Integration: Integrated the ReshapeAndCacheByGroup operator into the Ascend build system, including host-side definitions, tiling logic, and AICore kernel implementation for ascend910b and ascend910_93 SOC versions.
  • PyTorch Binding: Exposed the new custom operator to the PyTorch framework, allowing it to be called from Python code.
  • KV Cache Optimization: Updated the vllm_ascend attention mechanism to leverage the new reshape_and_cache_by_group operator for more efficient key-value cache updates, replacing previous reshape_and_cache and scatter_nd_update calls.
  • Unit Testing: Added comprehensive unit tests to validate the correctness and functionality of the ReshapeAndCacheByGroup operator.
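To make the operator's intent concrete, here is a hedged NumPy sketch of a reference-style reshape-and-cache: each token's key vector is scattered into a paged KV cache at the position given by a slot mapping. The function name, tensor layouts, and `block_size` here are illustrative assumptions for exposition, not the actual kernel interface.

```python
import numpy as np

def reshape_and_cache_ref(key, key_cache, slot_mapping, block_size):
    """Reference scatter of per-token keys into a paged KV cache.

    key:          [num_tokens, head_size]            (illustrative layout)
    key_cache:    [num_blocks, block_size, head_size]
    slot_mapping: flat slot index per token; block = slot // block_size
    """
    for token, slot in enumerate(slot_mapping):
        block = slot // block_size
        offset = slot % block_size
        key_cache[block, offset] = key[token]
    return key_cache

# Usage: 3 tokens with head_size 4, cache of 2 blocks of size 2.
key = np.arange(12, dtype=np.float32).reshape(3, 4)
cache = np.zeros((2, 2, 4), dtype=np.float32)
reshape_and_cache_ref(key, cache, [1, 2, 3], block_size=2)
```

A unit test for the real operator would compare its output against a reference loop of roughly this shape, as the PR's test description suggests.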


Changelog
  • csrc/build_aclnn.sh
    • Updated build script to include reshape_and_cache_by_group in custom operations for Ascend910B and Ascend910_93 SOCs.
  • csrc/reshape_and_cache_by_group/add_rms_norm_bias_torch_adpt.h
    • Added a PyTorch adapter for the reshape_and_cache_by_group operator, despite the file name discrepancy.
  • csrc/reshape_and_cache_by_group/op_host/CMakeLists.txt
    • Added CMake configuration for building the ReshapeAndCacheByGroup operator's host-side components.
  • csrc/reshape_and_cache_by_group/op_host/error_log.h
    • Added logging and error checking macros for the optiling namespace, used by the new operator.
  • csrc/reshape_and_cache_by_group/op_host/reshape_and_cache_by_group_common.cpp
    • Added common tiling logic for ReshapeAndCacheByGroup, including parameter parsing and slot mapping.
  • csrc/reshape_and_cache_by_group/op_host/reshape_and_cache_by_group_common.h
    • Added definitions for ReshapeAndCacheByGroupTilingData and ReshapeAndCacheByGroupCommonTiling.
  • csrc/reshape_and_cache_by_group/op_host/reshape_and_cache_by_group_def.cpp
    • Added the operator definition for ReshapeAndCacheByGroup within the CANN framework.
  • csrc/reshape_and_cache_by_group/op_host/reshape_and_cache_by_group_infershape.cpp
    • Added infer shape and data type implementations for the ReshapeAndCacheByGroup operator.
  • csrc/reshape_and_cache_by_group/op_host/reshape_and_cache_by_group_tiling.cpp
    • Added tiling function registration for the ReshapeAndCacheByGroup operator.
  • csrc/reshape_and_cache_by_group/op_kernel/reshape_and_cache_by_group.cpp
    • Added the AICore kernel implementation for reshape_and_cache_by_group.
  • csrc/reshape_and_cache_by_group/op_kernel/reshape_and_cache_by_group.h
    • Added kernel-side definitions for ReshapeAndCacheByGroupTilingData and ReshapeAndCacheByGroupBase.
  • csrc/reshape_and_cache_by_group/tiling_base/data_copy_transpose_tiling.h
    • Added utility functions and structures for data copy and transpose tiling.
  • csrc/reshape_and_cache_by_group/tiling_base/data_copy_transpose_tiling_def.h
    • Added data structure definitions for CopyTransposeTiling parameters.
  • csrc/reshape_and_cache_by_group/tiling_base/error_log.h
    • Added logging and error checking macros for the optiling namespace.
  • csrc/reshape_and_cache_by_group/tiling_base/tiling_base.h
    • Added the base class and framework for operator tiling implementations.
  • csrc/reshape_and_cache_by_group/tiling_base/tiling_key.h
    • Added utilities for generating unique tiling keys.
  • csrc/reshape_and_cache_by_group/tiling_base/tiling_templates_registry.h
    • Added a registry mechanism for managing and selecting tiling templates.
  • csrc/reshape_and_cache_by_group/tiling_base/tiling_type.h
    • Added enums and helper functions for defining tiling parameters and keys.
  • csrc/reshape_and_cache_by_group/tiling_base/tiling_util.h
    • Added utility functions related to tiling context and shape handling.
  • csrc/torch_binding.cpp
    • Modified to register the reshape_and_cache_by_group operator with PyTorch.
  • csrc/torch_binding_meta.cpp
    • Modified to add a meta-function for reshape_and_cache_by_group for PyTorch.
  • tests/ut/ops/test_reshape_and_cachebygroup
    • Added unit tests for the reshape_and_cache_by_group operator, including correctness checks against reference implementations.
  • vllm_ascend/attention/sfa_v1.py
    • Modified attention forward pass to utilize the new torch.ops._C_ascend.reshape_and_cache_by_group custom operator for KV cache updates.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Ascend operator, ReshapeAndCacheByGroup, to optimize caching operations. The changes include the operator's C++ implementation, kernel code, build scripts, and integration into the Python-level attention mechanism. A new unit test is also added to verify its functionality.

My review has identified a critical memory leak and a few high-severity issues related to maintainability and code clarity. Please address these points.

Additionally, per the repository's style guide, I have suggestions for the pull request title and summary to improve clarity and consistency.

Suggested PR Title:

[DS 3.2][Ops][Feature] Add ReshapeAndCacheByGroup Ascend op for optimization

Suggested PR Summary:

### What this PR does / why we need it?
This PR introduces a new Ascend operator, `ReshapeAndCacheByGroup`, to optimize the reshape, cache, and scatter operations. This is designed to leverage hardware features of Ascend NPUs for better performance.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
A new unit test file has been added in `tests/ut/ops/test_reshape_and_cachebygroup` to verify the correctness of the new operator.

// std::cout<<"device "<< idxGroups<<" "<< sizeof(allGroups[0])<<" "<< sizeof(uint32_t)<<" "<<device_size<<" "<<&allGroups[0]<<" "<<&allGroups[0].quotient<<std::endl;
void* devAddr = NULL;

aclrtMalloc(&devAddr, device_size, ACL_MEM_MALLOC_HUGE_FIRST);

critical

There appears to be a memory leak. aclrtMalloc is called to allocate devAddr, but there is no corresponding call to aclrtFree to release this memory. Since this tiling logic is executed for each operation, this will lead to a gradual memory leak on the device. The allocated memory should be freed after it's no longer needed, likely after the kernel execution completes.


+1, this is crazy

Comment on lines +16 to +17
#ifndef ADD_RMS_NORM_BIAS_TORCH_ADPT_H
#define ADD_RMS_NORM_BIAS_TORCH_ADPT_H

high

The filename add_rms_norm_bias_torch_adpt.h and the header guard ADD_RMS_NORM_BIAS_TORCH_ADPT_H do not match the content of the file, which implements the adapter for reshape_and_cache_by_group. This is misleading and can cause maintenance issues. Please rename the file to reshape_and_cache_by_group_torch_adpt.h and update the header guard accordingly.

Suggested change
#ifndef ADD_RMS_NORM_BIAS_TORCH_ADPT_H
#define ADD_RMS_NORM_BIAS_TORCH_ADPT_H
#ifndef RESHAPE_AND_CACHE_BY_GROUP_TORCH_ADPT_H
#define RESHAPE_AND_CACHE_BY_GROUP_TORCH_ADPT_H

#include "register/tilingdata_base.h"
#include "tiling/tiling_base.h"
// #include "op_log.h"
#include "error_log.h"

high

The file error_log.h is included here. However, another error_log.h file with similar content is also added in csrc/reshape_and_cache_by_group/tiling_base/. Having duplicated utility files increases maintenance overhead. It would be better to consolidate them into a single, shared header file in a common utility directory.

)
k_nope = k_nope.view(k_nope.shape[0], 1, -1)[: attn_metadata.num_actual_tokens]
k_pe = k_pe.view(k_pe.shape[0], 1, -1)[: attn_metadata.num_actual_tokens]
zt_block_size=128

high

The block size is hardcoded to 128. This magic number is used here and again on line 1190 and should be avoided: define it as a constant or retrieve it from a configuration to improve maintainability and clarity. For instance, it could be part of the model or attention configuration.
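A minimal sketch of the suggested fix, using hypothetical names (`DEFAULT_KV_CACHE_BLOCK_SIZE`, an `attn_config` object with a `block_size` attribute) rather than the project's actual configuration API:

```python
# Hypothetical default; in practice the value should live in the
# model or attention configuration, not as a literal at call sites.
DEFAULT_KV_CACHE_BLOCK_SIZE = 128

def get_block_size(attn_config=None):
    """Prefer the configured block size; fall back to the named default."""
    if attn_config is not None and getattr(attn_config, "block_size", None):
        return attn_config.block_size
    return DEFAULT_KV_CACHE_BLOCK_SIZE
```

Both call sites (here and line 1190) would then read the same named value, so a future block-size change touches one place.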

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Signed-off-by: ZT-AIA <1028681969@qq.com>
add_ops_compile_options(
OP_NAME ReshapeAndCacheByGroup
OPTIONS -o0
-g

Remove -o0 -g in the publish build.


void ReshapeAndCacheByGroupCommonTiling::PrintTilingData()
{
// OP_LOGD(context_->GetNodeName(), "Start WriteCacheByGroupListTilingData priting");

Use a proper debug log or remove the PrintTilingData function.

// OP_CHECK_NULL_WITH_CONTEXT(context_, kShape);
auto dim_num=kShape->GetStorageShape().GetDimNum();
if (dim_num<2||dim_num>7){
printf("[ERROR] ReshapeAndCacheByGroup Intput first params dim < 2 || dim_num>7");

Replace this with the custom op log; this is host stdout!

const gert::RuntimeAttrs *attrs = context_->GetAttrs();
auto slotMapping = attrs->GetListInt(0);
uint32_t slotMappingLen = slotMapping->GetSize();
auto slotMappingData=slotMapping->GetData();

Check for nullptr on attrs, slotMapping, and slotMappingData.

//slotMapping=[7,8,9,50,51,52,53,54,55,56,57,58,59,30,31,32,33,34,35,36,37,38,39,60,61,62]


auto kcacheShape = context_->GetInputShape(DIM_1);

Indentation


#ifdef ZTDEBUG
std::cout<<"luanxu: "<<j<< slotMappingData[j]<<std::endl;
#endif
j++;

Using a binary search here would improve performance.
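The suggestion replaces the linear scan over slotMappingData with an O(log n) lookup. A hedged Python sketch using the standard bisect module (the C++ tiling code would use std::lower_bound); note this only applies where the slot mapping, or the run being searched, is sorted, which the suggestion implicitly assumes:

```python
from bisect import bisect_left

def find_slot(slot_mapping, slot):
    """Return the index of `slot` in a sorted slot mapping, or -1.

    O(log n) via binary search, versus the O(n) linear scan in the
    tiling loop this comment refers to.
    """
    i = bisect_left(slot_mapping, slot)
    if i < len(slot_mapping) and slot_mapping[i] == slot:
        return i
    return -1
```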

uint32_t idxGroups = 0;


while (idxSlotmap < slotMappingLen) {

Move the entire SlotMapping-compress logic into the kernel, and make the SlotMapping input a tensor. This way you achieve:

  • Multi-kernel acceleration
  • No need to copy SlotMapping to the host
  • Removal of potentially large tiling data (eliminating copy and initialization overhead)
  • No need to allocate device memory in tiling; there is no proper timing to free it anyway
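To make the "compress" step concrete, here is a hedged Python sketch of what collapsing a slot mapping into contiguous (start, length) groups might look like; per the comment, this logic could run inside the kernel with the slot mapping passed as a tensor instead of a host-side attribute. The function name and output format are illustrative, not the PR's actual tiling code.

```python
def compress_slot_mapping(slot_mapping):
    """Collapse runs of consecutive slots into (start, length) groups.

    e.g. [7, 8, 9, 50, 51, 52] -> [(7, 3), (50, 3)], so each group can
    be handled as one contiguous copy instead of per-slot scatters.
    """
    groups = []
    for slot in slot_mapping:
        if groups and slot == groups[-1][0] + groups[-1][1]:
            # Extends the current run: grow its length by one.
            groups[-1] = (groups[-1][0], groups[-1][1] + 1)
        else:
            # Run broken: start a new group at this slot.
            groups.append((slot, 1))
    return groups
```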

.DataType({ge::DT_INT8, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("slotMapping").AttrType(OPTIONAL).ListInt({});

slotMapping can't be optional

ZT-AIA added 6 commits March 19, 2026 10:27
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <1028681969@qq.com>
@yiz-liu yiz-liu added this to the v0.18.0rc1 milestone Mar 25, 2026