Skip to content

Commit 3fb80ee

Browse files
momo609wangxiaoxin-sherie
andauthored
add mlp tp optimze (#2120)
### What this PR does / why we need it? For dense models, by not applying tensor parallelism (TP) to the attention module and applying TP to the MLP module, the allreduce operations in the attention module can be eliminated, thereby reducing computational overhead. However, this approach increases memory usage, so the environment variable VLLM_ASCEND_ENABLE_MLP_OPTIMZE is used to control this optimization. - vLLM main: vllm-project/vllm@b17109b Signed-off-by: wangxiaoxin-sherie <[email protected]> Co-authored-by: wangxiaoxin-sherie <[email protected]>
1 parent 973a7cf commit 3fb80ee

File tree

6 files changed

+729
-2
lines changed

6 files changed

+729
-2
lines changed

tests/ut/ops/test_linear.py

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
import os
2+
import unittest
3+
from unittest import mock
4+
5+
import torch
6+
7+
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
8+
AscendMlpMergedColumnParallelLinear,
9+
AscendMlpRowParallelLinear, LinearBase,
10+
QuantizationConfig)
11+
12+
13+
class TestAscendMlpRowParallelLinear(unittest.TestCase):
14+
15+
def setUp(self):
16+
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
17+
self.tensor_parallel_world_size = 2
18+
self.tensor_parallel_rank = 0
19+
self.mlp_tensor_parallel_world_size = 2
20+
self.mlp_tensor_parallel_rank = 1
21+
22+
self.get_tensor_model_parallel_world_size_patch = mock.patch(
23+
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
24+
return_value=self.tensor_parallel_world_size)
25+
self.get_tensor_model_parallel_rank_patch = mock.patch(
26+
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
27+
return_value=self.tensor_parallel_rank)
28+
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
29+
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
30+
return_value=self.mlp_tensor_parallel_world_size)
31+
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
32+
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
33+
return_value=self.mlp_tensor_parallel_rank)
34+
35+
self.get_tensor_model_parallel_world_size_mock = \
36+
self.get_tensor_model_parallel_world_size_patch.start()
37+
self.get_tensor_model_parallel_rank_mock = \
38+
self.get_tensor_model_parallel_rank_patch.start()
39+
self.get_mlp_tensor_model_parallel_world_size_mock = \
40+
self.get_mlp_tensor_model_parallel_world_size_patch.start()
41+
self.get_mlp_tensor_model_parallel_rank_mock = \
42+
self.get_mlp_tensor_model_parallel_rank_patch.start()
43+
44+
self.split_tensor_along_last_dim_patch = mock.patch(
45+
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
46+
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
47+
self.tensor_model_parallel_all_reduce_patch = mock.patch(
48+
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
49+
return_value=torch.randn(10, 8))
50+
self.tensor_model_parallel_all_reduce_mock = \
51+
self.tensor_model_parallel_all_reduce_patch.start()
52+
self.split_tensor_along_last_dim_mock = \
53+
self.split_tensor_along_last_dim_patch.start()
54+
self.get_mlp_tp_group_patch = \
55+
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
56+
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
57+
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
58+
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
59+
mock.MagicMock()
60+
61+
def tearDown(self):
62+
self.get_tensor_model_parallel_world_size_patch.stop()
63+
self.get_tensor_model_parallel_rank_patch.stop()
64+
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
65+
self.get_mlp_tensor_model_parallel_rank_patch.stop()
66+
self.split_tensor_along_last_dim_patch.stop()
67+
self.tensor_model_parallel_all_reduce_patch.stop()
68+
self.get_mlp_tp_group_patch.stop()
69+
70+
def test_init_with_down_proj_prefix(self):
71+
layer = AscendMlpRowParallelLinear(input_size=16,
72+
output_size=8,
73+
prefix="down_proj")
74+
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
75+
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
76+
self.assertTrue(layer.enable_mlp_optimze)
77+
78+
def test_forward_with_mlp_optimize(self):
79+
layer = AscendMlpRowParallelLinear(
80+
input_size=16,
81+
output_size=8,
82+
prefix="down_proj",
83+
input_is_parallel=False,
84+
)
85+
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
86+
layer(input_tensor)
87+
88+
self.split_tensor_along_last_dim_mock.assert_called_once_with(
89+
input_tensor, num_partitions=layer.tp_size)
90+
91+
def test_forward_without_mlp_optimize(self):
92+
layer = AscendMlpRowParallelLinear(
93+
input_size=16,
94+
output_size=8,
95+
prefix="other",
96+
input_is_parallel=False,
97+
)
98+
input_tensor = torch.randn(16, 8)
99+
layer(input_tensor)
100+
101+
self.split_tensor_along_last_dim_mock.assert_called_once_with(
102+
input_tensor, num_partitions=layer.tp_size)
103+
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
104+
105+
def test_skip_bias_add(self):
106+
layer = AscendMlpRowParallelLinear(
107+
input_size=16,
108+
output_size=8,
109+
skip_bias_add=True,
110+
)
111+
input_tensor = torch.randn(16, 8)
112+
output, bias = layer(input_tensor)
113+
114+
self.assertIsNotNone(bias)
115+
116+
def test_no_reduce_results(self):
117+
layer = AscendMlpRowParallelLinear(input_size=16,
118+
output_size=8,
119+
reduce_results=False,
120+
bias=False)
121+
input_tensor = torch.randn(16, 8)
122+
layer(input_tensor)
123+
124+
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
125+
126+
def test_input_not_parallel(self):
127+
layer = AscendMlpRowParallelLinear(input_size=16,
128+
output_size=8,
129+
input_is_parallel=False)
130+
input_tensor = torch.randn(16, 8)
131+
layer(input_tensor)
132+
133+
self.split_tensor_along_last_dim_mock.assert_called_once()
134+
135+
def test_exception_when_reduce_false_and_bias(self):
136+
with self.assertRaises(ValueError):
137+
AscendMlpRowParallelLinear(input_size=16,
138+
output_size=8,
139+
reduce_results=False,
140+
bias=True,
141+
skip_bias_add=False)
142+
143+
144+
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
145+
146+
def setUp(self):
147+
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
148+
# Mock distributed functions
149+
self.mlp_tp_size_patch = \
150+
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
151+
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
152+
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
153+
154+
self.mlp_tp_rank_patch = \
155+
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
156+
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
157+
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
158+
159+
self.tp_size_patch = \
160+
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
161+
self.tp_size_mock = self.tp_size_patch.start()
162+
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
163+
164+
self.tp_rank_patch = \
165+
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
166+
self.tp_rank_mock = self.tp_rank_patch.start()
167+
self.tp_rank_mock.return_value = 1 # Current GPU rank
168+
169+
# Mock divide function (assumed to be in your module)
170+
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
171+
self.divide_mock = self.divide_patch.start()
172+
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
173+
174+
# Mock QuantizationConfig and QuantMethod
175+
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
176+
177+
# Mock LinearBase initialization
178+
self.linear_base_init_patch = mock.patch.object(
179+
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
180+
self.linear_base_init_patch.start()
181+
182+
self.quant_method_mock = mock.MagicMock()
183+
184+
def mock_linear_base_init(self, instance, *args, **kwargs):
185+
instance.quant_method = self.quant_method_mock
186+
instance.params_dtype = mock.MagicMock()
187+
188+
instance.input_size = 16
189+
instance.output_size = 8
190+
instance.output_size_per_partition = 4
191+
instance.params_dtype = torch.float32
192+
193+
def tearDown(self):
194+
self.mlp_tp_size_patch.stop()
195+
self.mlp_tp_rank_patch.stop()
196+
self.tp_size_patch.stop()
197+
self.tp_rank_patch.stop()
198+
self.divide_patch.stop()
199+
self.linear_base_init_patch.stop()
200+
201+
def test_mlp_optimize_initialization(self):
202+
# Test when prefix contains "gate_up_proj"
203+
with mock.patch.object(torch.nn.Module, 'register_parameter'):
204+
layer = AscendMlpColumnParallelLinear(
205+
input_size=16,
206+
output_size=8,
207+
prefix="model.layers.0.gate_up_proj",
208+
bias=False,
209+
)
210+
211+
# Verify MLP optimization flags
212+
self.assertTrue(layer.enable_mlp_optimze)
213+
self.assertEqual(layer.tp_size, 2)
214+
self.assertEqual(layer.tp_rank, 0)
215+
self.assertEqual(layer.input_size_per_partition, 16)
216+
self.assertEqual(layer.output_size_per_partition, 4)
217+
218+
# Check quant_method.create_weights was called
219+
self.quant_method_mock.create_weights.assert_called_once()
220+
221+
def test_regular_parallel_initialization(self):
222+
# Test when prefix does NOT contain "gate_up_proj"
223+
with mock.patch.object(torch.nn.Module, 'register_parameter'):
224+
layer = AscendMlpColumnParallelLinear(
225+
input_size=16,
226+
output_size=8,
227+
prefix="model.layers.0.q_proj",
228+
quant_config=self.quant_config_mock,
229+
bias=False,
230+
)
231+
232+
# Verify regular TP flags
233+
self.assertFalse(layer.enable_mlp_optimze)
234+
self.assertEqual(layer.tp_size, 4)
235+
self.assertEqual(layer.tp_rank, 1)
236+
self.assertEqual(layer.input_size_per_partition, 16)
237+
self.assertEqual(layer.output_size_per_partition, 4)
238+
# Check quant_method.create_weights was called
239+
self.quant_method_mock.create_weights.assert_called_once()
240+
241+
def test_output_sizes_handling(self):
242+
# Test when output_sizes is provided
243+
with mock.patch.object(torch.nn.Module, 'register_parameter'):
244+
layer = AscendMlpColumnParallelLinear(
245+
input_size=16,
246+
output_size=8,
247+
output_sizes=[4, 4],
248+
prefix="model.layers.0.qkv_proj",
249+
quant_config=self.quant_config_mock,
250+
bias=False,
251+
)
252+
253+
# Verify output_partition_sizes
254+
self.assertEqual(layer.output_partition_sizes, [2])
255+
256+
257+
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
258+
259+
def setUp(self):
260+
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
261+
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
262+
self.mlp_world_size_patch = \
263+
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
264+
self.tensor_world_size_patch = \
265+
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
266+
self.mlp_world_size_patch.start()
267+
self.tensor_world_size_patch.start()
268+
269+
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
270+
self.mlp_rank_patch = \
271+
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
272+
self.tensor_rank_patch = \
273+
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
274+
self.mlp_rank_patch.start()
275+
self.tensor_rank_patch.start()
276+
277+
# Mock all_gather methods
278+
self.get_mlp_tp_group_patch = \
279+
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
280+
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
281+
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
282+
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
283+
self.tensor_model_parallel_all_gather_patch = mock.patch(
284+
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
285+
return_value=torch.randn(10, 8))
286+
self.tensor_model_parallel_all_gather_mock = \
287+
self.tensor_model_parallel_all_gather_patch.start()
288+
289+
# Mock AscendMlpColumnParallelLinear's __init__
290+
self.linear_init_patch = mock.patch.object(
291+
AscendMlpColumnParallelLinear,
292+
"__init__",
293+
side_effect=self.mock_linear_init)
294+
self.linear_init_patch.start()
295+
296+
# Create mock objects
297+
self.quant_method_mock = mock.MagicMock()
298+
self.apply_output = torch.randn(2, 8)
299+
300+
self.quant_method_mock.apply.return_value = self.apply_output
301+
302+
def mock_linear_init(self, instance, *args, **kwargs):
303+
torch.nn.Module.__init__(instance)
304+
# Set quant_method and other attributes
305+
instance.quant_method = self.quant_method_mock
306+
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
307+
instance.input_size = 16
308+
instance.output_size = 8
309+
instance.gather_output = False
310+
instance.skip_bias_add = False
311+
instance.return_bias = True
312+
313+
def test_forward_with_enable_mlp_optimze(self):
314+
# Setup input
315+
input_tensor = torch.randn(1, 16)
316+
317+
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
318+
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
319+
output_sizes=[8],
320+
bias=True,
321+
gather_output=False,
322+
skip_bias_add=False,
323+
params_dtype=torch.float32,
324+
quant_config=None,
325+
prefix="other_proj")
326+
327+
# Call forward
328+
output, bias = layer(input_tensor)
329+
330+
# Validate calls
331+
self.assertEqual(output.shape, self.apply_output.shape)
332+
333+
def test_forward_without_enable_mlp_optimze(self):
334+
# Setup input
335+
input_tensor = torch.randn(1, 16)
336+
337+
# Create instance with prefix not containing "gate_up_proj"
338+
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
339+
output_sizes=[8],
340+
bias=True,
341+
gather_output=False,
342+
skip_bias_add=False,
343+
params_dtype=torch.float32,
344+
quant_config=None,
345+
prefix="other_proj")
346+
347+
# Call forward
348+
output, bias = layer(input_tensor)
349+
350+
# Validate calls
351+
self.quant_method_mock.apply.assert_called_once_with(
352+
layer, input_tensor, layer.bias)
353+
self.tensor_model_parallel_all_gather_mock.assert_not_called()
354+
self.assertEqual(output.shape, self.apply_output.shape)
355+
356+
def tearDown(self):
357+
self.linear_init_patch.stop()
358+
self.mlp_world_size_patch.stop()
359+
self.tensor_world_size_patch.stop()
360+
self.mlp_rank_patch.stop()
361+
self.tensor_rank_patch.stop()
362+
self.get_mlp_tp_group_mock.stop()
363+
self.tensor_model_parallel_all_gather_mock.stop()

tests/ut/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
356356
# ascend custom op is not registered
357357
utils.register_ascend_customop()
358358
# should call register_oot three
359-
self.assertEqual(mock_customop.register_oot.call_count, 3)
359+
self.assertEqual(mock_customop.register_oot.call_count, 6)
360360
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
361361

362362
# ascend custom op is already registered
363363
utils.register_ascend_customop()
364364
# should not register_oot again, thus only called three in this ut
365-
self.assertEqual(mock_customop.register_oot.call_count, 3)
365+
self.assertEqual(mock_customop.register_oot.call_count, 6)
366366

367367

368368
class TestProfileExecuteDuration(TestBase):

0 commit comments

Comments
 (0)