Skip to content

Commit de345b5

Browse files
committed
Refactor model runner
Signed-off-by: weiguihua2 <[email protected]>
1 parent 38c5dea commit de345b5

File tree

3 files changed

+21
-34
lines changed

3 files changed

+21
-34
lines changed

vllm_ascend/lora/punica_wrapper/lora_ops.py

Lines changed: 9 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -52,14 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor,
5252
slice_offset: int,
5353
slice_size: int,
5454
add_inputs: bool = True):
55-
return torch.ops._C.bgmv_expand(
56-
inputs,
57-
lora_b_weights,
58-
lora_indices_tensor,
59-
output_tensor,
60-
slice_offset,
61-
slice_size
62-
)
55+
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
56+
lora_indices_tensor, output_tensor,
57+
slice_offset, slice_size)
6358

6459

6560
def sgmv_shrink(
@@ -74,8 +69,9 @@ def sgmv_shrink(
7469
token_nums: int,
7570
scaling: float,
7671
):
77-
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor,
78-
seq_len_tensor, output_tensor, scaling)
72+
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
73+
lora_indices_tensor, seq_len_tensor,
74+
output_tensor, scaling)
7975

8076

8177
def sgmv_expand(inputs: torch.Tensor,
@@ -111,12 +107,6 @@ def sgmv_expand_slice(inputs: torch.Tensor,
111107
slice_offset: int,
112108
slice_size: int,
113109
add_inputs: bool = False):
114-
return torch.ops._C.sgmv_expand(
115-
inputs,
116-
lora_b_weights,
117-
lora_indices_tensor,
118-
seq_len_tensor,
119-
output_tensor,
120-
slice_offset,
121-
slice_size
122-
)
110+
return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
111+
lora_indices_tensor, seq_len_tensor,
112+
output_tensor, slice_offset, slice_size)

vllm_ascend/meta_registration.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -80,23 +80,18 @@ def get_masked_input_and_mask_meta(input: torch.Tensor,
8080

8181
return masked_input, mask
8282

83-
def bgmv_expand_meta(x: torch.Tensor,
84-
weight: torch.Tensor,
85-
indices: torch.Tensor,
86-
y: torch.Tensor,
87-
slice_offset: int,
88-
slice_size: int):
83+
84+
def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
85+
indices: torch.Tensor, y: torch.Tensor, slice_offset: int,
86+
slice_size: int):
8987

9088
y_out = torch.empty_like(y)
9189
return y_out
9290

93-
def sgmv_expand_meta(x: torch.Tensor,
94-
weight: torch.Tensor,
95-
lora_indices: torch.Tensor,
96-
seq_len: torch.Tensor,
97-
y: torch.Tensor,
98-
slice_offset: int,
99-
slice_size: int):
91+
92+
def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
93+
lora_indices: torch.Tensor, seq_len: torch.Tensor,
94+
y: torch.Tensor, slice_offset: int, slice_size: int):
10095

10196
y_out = torch.empty_like(y)
10297
return y_out

vllm_ascend/torchair/torchair_attention.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,9 +98,11 @@ class AscendTorchairMetadata(AscendMetadata):
9898

9999
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
100100

101-
def __init__(self,
101+
def __init__(
102+
self,
102103
vllm_config: VllmConfig,
103-
device: torch.device,):
104+
device: torch.device,
105+
):
104106
super().__init__(vllm_config, device)
105107
self.max_num_blocks_per_req = cdiv(
106108
self.model_config.max_model_len,

0 commit comments

Comments
 (0)