3 changes: 2 additions & 1 deletion .github/workflows/_e2e_test.yaml
@@ -183,7 +183,8 @@ jobs:
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W8A8
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
33 changes: 29 additions & 4 deletions tests/e2e/multicard/test_offline_inference_distributed.py
@@ -30,11 +30,20 @@
from tests.e2e.conftest import VllmRunner

os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

QWEN_DENSE_MODELS = [
"vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
]

QWEN_W4A8_OLD_VERSION_MODELS = [
"vllm-ascend/Qwen3-8B-W4A8",
]

QWEN_W4A8_NEW_VERSION_MODELS = [
"Anionex/Qwen3-1.7B-W4A8-V1", # TODO: move it into vllm-ascend ci modelscope repo
]

DEEPSEEK_W4A8_MODELS = [
"vllm-ascend/DeepSeek-V3-W4A8-Pruing",
"vllm-ascend/DeepSeek-V3.1-W4A8-puring"
@@ -98,20 +107,36 @@ def test_models_distributed_Qwen3_W8A8():
vllm_model.generate_greedy(example_prompts, max_tokens)


def test_models_distributed_Qwen3_W4A8DYNAMIC():
example_prompts = [
@pytest.mark.parametrize("model", QWEN_W4A8_OLD_VERSION_MODELS)
def test_models_distributed_Qwen3_W4A8DYNAMIC_old_version(model):
prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download(model),
max_model_len=8192,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(prompts, max_tokens)


@pytest.mark.parametrize("model", QWEN_W4A8_NEW_VERSION_MODELS)
def test_models_distributed_Qwen3_W4A8DYNAMIC_new_version(model):
prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/Qwen3-8B-W4A8"),
snapshot_download(model),
max_model_len=8192,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
vllm_model.generate_greedy(prompts, max_tokens)


@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
103 changes: 89 additions & 14 deletions tests/ut/quantization/test_w4a8_dynamic.py
@@ -9,25 +9,31 @@

class TestAscendW4A8DynamicLinearMethod(TestBase):

def setUp(self):
with patch(
'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
) as mock_get_current_vllm_config:
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(
quant_description={"group_size": 256})
mock_vllm_config.scheduler_config = Mock(
max_num_batched_tokens=2048,
max_model_len=2048,
enable_chunked_prefill=False)
mock_get_current_vllm_config.return_value = mock_vllm_config
self.method = AscendW4A8DynamicLinearMethod()
self.method.group_size = 8
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
mock_get_tp_world_size.return_value = 1
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(
quant_description={"group_size": 256})
mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
max_model_len=2048,
enable_chunked_prefill=False)
mock_get_current_vllm_config.return_value = mock_vllm_config
self.method = AscendW4A8DynamicLinearMethod()
self.method.group_size = 8

def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (32, 8))
# new quant version weight
self.method.new_quant_version = True
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (16, 8))
self.assertEqual(weight["_packed_dim"], 0)
self.assertEqual(weight["_packed_factor"], 2)

def test_get_pergroup_param(self):
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
@@ -39,6 +45,75 @@ def test_get_pergroup_param(self):
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
# new quant version weight
self.method.new_quant_version = True
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="column")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 1))
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="row")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 16))

@patch('torch_npu.npu_convert_weight_to_int4pack')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu,
mock_npu_convert_weight):
mock_npu.side_effect = lambda: torch.zeros(
(1, 32), dtype=torch.float32)
mock_npu_convert_weight.return_value = torch.zeros((32, 4),
dtype=torch.int32)
# old quant version weight
layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
dtype=torch.int8),
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset = torch.nn.Parameter(torch.empty_like(
layer.weight_scale.data),
requires_grad=False)
layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
layer.weight_scale_second.data),
requires_grad=False)
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "weight_scale_bias"))
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.method.new_quant_version = True
new_layer = torch.nn.Module()
new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
dtype=torch.int8),
requires_grad=False)
new_layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
new_layer.weight_scale.data),
requires_grad=False)
new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset_second = torch.nn.Parameter(
torch.empty_like(new_layer.weight_scale_second.data),
requires_grad=False)
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
(32, 1), dtype=torch.float32),
requires_grad=False)
self.method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
self.assertTrue(hasattr(new_layer, "weight_scale_second"))
self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))


class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
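For reference, a minimal sketch of the int4 packing convention the new assertions describe: two int4 values are stored per int8 byte along dim 0, so a (32, 8) weight becomes (16, 8) with packed_dim=0 and packed_factor=2. The helper below is illustrative only (the name pack_int4_pairs_dim0 and the exact nibble order are assumptions); the real conversion goes through torch_npu.npu_convert_weight_to_int4pack, which the test mocks.

import torch

def pack_int4_pairs_dim0(w: torch.Tensor) -> torch.Tensor:
    # w holds int4 values in [-8, 7] stored as int8; dim 0 must be even.
    lo = (w[0::2] & 0xF).to(torch.uint8)       # low nibble from even rows
    hi = (w[1::2] & 0xF).to(torch.uint8) << 4  # high nibble from odd rows
    return (lo | hi).view(torch.int8)          # reinterpret the bits as int8

w_int4 = torch.randint(-8, 8, (32, 8), dtype=torch.int8)
packed = pack_int4_pairs_dim0(w_int4)
assert packed.shape == (16, 8) and packed.dtype == torch.int8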
91 changes: 89 additions & 2 deletions tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
@@ -10,14 +10,30 @@

class TestAscendW4A8DynamicLinearMethod(TestBase):

def setUp(self):
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
)
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
mock_get_tp_world_size.return_value = 1
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(
quant_description={"group_size": 256})
mock_get_current_vllm_config.return_value = mock_vllm_config
self.method = TorchairAscendW4A8DynamicLinearMethod()
self.method.group_size = 8

def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (32, 8))
# new quant version weight
self.method.new_quant_version = True
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (16, 8))
self.assertEqual(weight["_packed_dim"], 0)
self.assertEqual(weight["_packed_factor"], 2)

def test_get_pergroup_param(self):
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
@@ -29,6 +45,75 @@ def test_get_pergroup_param(self):
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
# new quant version weight
self.method.new_quant_version = True
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="column")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 1))
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="row")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 16))

@patch('torch_npu.npu_convert_weight_to_int4pack')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu,
mock_npu_convert_weight):
mock_npu.side_effect = lambda: torch.zeros(
(1, 32), dtype=torch.float32)
mock_npu_convert_weight.return_value = torch.zeros((32, 4),
dtype=torch.int32)
# old quant version weight
layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
dtype=torch.int8),
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset = torch.nn.Parameter(torch.empty_like(
layer.weight_scale.data),
requires_grad=False)
layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
layer.weight_scale_second.data),
requires_grad=False)
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "weight_scale_bias"))
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.method.new_quant_version = True
new_layer = torch.nn.Module()
new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
dtype=torch.int8),
requires_grad=False)
new_layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
new_layer.weight_scale.data),
requires_grad=False)
new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset_second = torch.nn.Parameter(
torch.empty_like(new_layer.weight_scale_second.data),
requires_grad=False)
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
(32, 1), dtype=torch.float32),
requires_grad=False)
self.method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
self.assertTrue(hasattr(new_layer, "weight_scale_second"))
self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))


class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
@@ -42,7 +127,9 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
@patch("vllm_ascend.ascend_config.get_ascend_config")
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config'
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
)
24 changes: 23 additions & 1 deletion vllm_ascend/quantization/quant_config.py
@@ -254,9 +254,22 @@ def create_weights(
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)

# Extract packing information (if present)
packed_dim = weight_dict.pop("_packed_dim", None)
packed_factor = weight_dict.pop("_packed_factor", None)

for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})

# Set packing attributes if the weight is packed
if packed_dim is not None and packed_factor is not None:
set_weight_attrs(param, {
"packed_dim": packed_dim,
"packed_factor": packed_factor
})

layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)

@@ -275,8 +288,17 @@ def create_weights(
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)

# NOTE: In the w4a8 quantization implementation, the scale_bias shape
# for down_proj and o_proj is [output_size, 16]; for all other layers
# it is [output_size, 1].
layer_type = "row" if isinstance(layer,
RowParallelLinear) else "others"

pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition, output_size_per_partition, params_dtype)
input_size_per_partition,
output_size_per_partition,
params_dtype,
layer_type=layer_type)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
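For reference, the shape rule that the new layer_type argument encodes can be summarized in a short sketch; the helper name make_scale_bias is hypothetical, while the shapes and dtype follow the NOTE above and the unit tests: RowParallelLinear layers (down_proj, o_proj) get a [output_size, 16] scale_bias, all other linear layers get [output_size, 1].

import torch

def make_scale_bias(output_size: int, layer_type: str) -> torch.Tensor:
    # "row" -> RowParallelLinear (down_proj / o_proj); anything else -> [out, 1]
    cols = 16 if layer_type == "row" else 1
    return torch.zeros((output_size, cols), dtype=torch.float32)

assert make_scale_bias(32, "row").shape == (32, 16)
assert make_scale_bias(32, "column").shape == (32, 1)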