@@ -78,10 +78,6 @@ def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
78
78
...], scale : float , ** kwargs ):
79
79
"""
80
80
Performs GEMM for multiple slices of lora_a.
81
- When `is_prefill is` true, it indicates that it is currently the
82
- prefill stage, and the `_shrink_prefill` function should be called.
83
- Otherwise, it is the decode stage, and the _shrink_decode function
84
- should be called.
85
81
86
82
Semantics:
87
83
for i in range(len(lora_a_stacked)):
@@ -129,7 +125,7 @@ def add_expand(self,
129
125
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
130
126
bias's weight
131
127
output_slices (Tuple[int, ...]): Every slice's size
132
- add_inputs (bool): Defaults to True.
128
+ add_inputs (bool): Defaults to True.
133
129
"""
134
130
y_org = y
135
131
y = y .view (- 1 , y .shape [- 1 ])
@@ -226,7 +222,7 @@ def add_lora_linear(self,
226
222
227
223
if buffer is None :
228
224
r = lora_b_stacked [0 ].size (- 1 )
229
- # We set the buffer to be float32 by default , refer to:
225
+ # We set the buffer to be float32 by default, refer to:
230
226
# https://github.com/triton-lang/triton/issues/1387
231
227
buffer = torch .zeros ( # type: ignore
232
228
(len (output_slices ), x .size (0 ), r ),
@@ -268,16 +264,16 @@ def add_lora_logits(self,
268
264
y (torch.Tensor): Output tensor.
269
265
x (torch.Tensor): Input tensor.
270
266
lora_a_stacked (torch.Tensor): lora_a's weights.
271
- lora_b_stacked (torch.Tensor):lora_b's weights.
267
+ lora_b_stacked (torch.Tensor): lora_b's weights.
272
268
scale (float): Scaling factor.
273
- buffer (Optional[torch.Tensor]):Default to None.
269
+ buffer (Optional[torch.Tensor]): Default to None.
274
270
"""
275
271
y_org = y
276
272
y = y .view (- 1 , y .shape [- 1 ])
277
273
x = x .view (- 1 , x .shape [- 1 ])
278
274
r = lora_b_stacked .size (- 1 )
279
275
if buffer is None :
280
- # We set the buffer to be float32 by default , refer to:
276
+ # We set the buffer to be float32 by default, refer to:
281
277
# https://github.com/triton-lang/triton/issues/1387
282
278
buffer = torch .zeros ((x .size (0 ), r ),
283
279
dtype = torch .float32 ,
0 commit comments