
Commit d8fa1a8

Support partial_rotary_factor (Phi-4 mini)

1 parent 2e630ae

File tree: 10 files changed, +63 −36 lines
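This commit plumbs a partial rotary factor from the model config down to the RoPE kernels: only the first int(head_dim * partial_rotary_factor) features of each head are rotated, and the sin/cos tables shrink to that width (passed into the extension as sincos_size). A minimal sketch of the arithmetic, with illustrative values not taken from any particular model config:

# Illustrative values only; a real model supplies head_dim and
# partial_rotary_factor via its config.json (the factor defaults to 1.0).
head_dim = 128
partial_rotary_factor = 0.75

rotary_dim = int(head_dim * partial_rotary_factor)  # width of the rotated slice
print(rotary_dim)             # 96 -> rotated with RoPE
print(head_dim - rotary_dim)  # 32 -> passed through unchanged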

exllamav2/attn.py
Lines changed: 1 addition & 0 deletions

@@ -317,6 +317,7 @@ def load(self, device_context: bool = True):
     cfg.max_seq_len,
     self.has_residual,
     self.archparams.rope_style.value,
+    int(cfg.head_dim * cfg.partial_rotary_factor),
     q_norm,
     k_norm,
     post_norm_weight,

exllamav2/config.py
Lines changed: 3 additions & 0 deletions

@@ -127,6 +127,7 @@ class ExLlamaV2Config:
     checkpoint_offset_qzeros: bool
     mrope_section: list | None
     attention_multiplier: float | None
+    partial_rotary_factor: float | None

     vision_model_type: str | None
     vision_head_dim: int | None

@@ -361,6 +362,8 @@ def prepare(self, no_tensors: bool = False):
     self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0, opt_subkey = "text_config")
     self.sliding_window_pattern = read(read_config, int, ["sliding_window_pattern"], 1)

+    self.partial_rotary_factor = read(read_config, float, "partial_rotary_factor", 1.0)
+
     rs = read(read_config, dict, "rope_scaling", None)
     if rs:
         scaling_type = rs.get("type", None)
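For reference, the new config field behaves like a plain optional entry in config.json: if it is absent the factor defaults to 1.0 and the full head dimension is rotated. A rough standalone sketch of that lookup (the actual `read` helper in exllamav2/config.py handles more cases, such as nested subkeys):

import json

# Hypothetical standalone equivalent of the read(...) call above.
with open("config.json") as f:
    cfg_json = json.load(f)

partial_rotary_factor = float(cfg_json.get("partial_rotary_factor", 1.0))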

exllamav2/exllamav2_ext/cuda/q_attn.cu
Lines changed: 3 additions & 0 deletions

@@ -104,6 +104,7 @@ QAttn::QAttn
     int _max_seq_len,
     bool _has_residual,
     int _rope_style,
+    int _sincos_size,
     half* _q_norm,
     half* _k_norm,
     half* _post_layernorm,

@@ -132,6 +133,7 @@ QAttn::QAttn
     max_seq_len(_max_seq_len),
     has_residual(_has_residual),
     rope_style(_rope_style),
+    sincos_size(_sincos_size),
     q_norm(_q_norm),
     k_norm(_k_norm),
     post_layernorm(_post_layernorm),

@@ -305,6 +307,7 @@ void QAttn::forward_cuda_1_run
         past_len,
         past_lens,
         rope_style == ROPE_STYLE_NEOX,
+        sincos_size,
         graph,
         KernelLabels::ROPE
     );

exllamav2/exllamav2_ext/cuda/q_attn.cuh
Lines changed: 2 additions & 0 deletions

@@ -76,6 +76,7 @@ public:
     bool has_residual;
    bool residual_fp32;
    int rope_style;
+    int sincos_size;

    bool use_graphs;
    std::unordered_map<QAttn_params_const, Graph*, QAttn_params_const_hash> graph_map;

@@ -103,6 +104,7 @@ public:
     int _max_seq_len,
     bool _has_residual,
     int _rope_style,
+    int _sincos_size,
     half* _q_norm,
     half* _k_norm,
     half* _post_layernorm,

exllamav2/exllamav2_ext/cuda/rope.cu
Lines changed: 27 additions & 19 deletions

@@ -17,15 +17,16 @@ __forceinline__ __device__ void rope_cuda_arr_neox
     int num_heads,
     int past_len,
     const int32_t* __restrict__ past_lens,
-    int threads_y
+    int threads_y,
+    int sincos_size
 )
 {
     MatrixView_half_rw x_(x, MAX_ROWS, head_dim);
-    MatrixView_half sin_(sin, MAX_POS_EMBEDDINGS, head_dim);
-    MatrixView_half cos_(cos, MAX_POS_EMBEDDINGS, head_dim);
+    MatrixView_half sin_(sin, MAX_POS_EMBEDDINGS, sincos_size);
+    MatrixView_half cos_(cos, MAX_POS_EMBEDDINGS, sincos_size);

     int column = (blockIdx.x * THREADS_X + threadIdx.x) * 2;
-    int half_dim = head_dim / 2;
+    int half_dim = sincos_size / 2;
     if (column >= half_dim) return;

     int row = blockIdx.y * threads_y + threadIdx.y;

@@ -76,15 +77,16 @@ __forceinline__ __device__ void rope_cuda_arr_gptj
     int num_heads,
     int past_len,
     const int32_t* __restrict__ past_lens,
-    int threads_y
+    int threads_y,
+    int sincos_size
 )
 {
     MatrixView_half_rw x_(x, MAX_ROWS, head_dim);
-    MatrixView_half sin_(sin, MAX_POS_EMBEDDINGS, head_dim);
-    MatrixView_half cos_(cos, MAX_POS_EMBEDDINGS, head_dim);
+    MatrixView_half sin_(sin, MAX_POS_EMBEDDINGS, sincos_size);
+    MatrixView_half cos_(cos, MAX_POS_EMBEDDINGS, sincos_size);

     int column = (blockIdx.x * THREADS_X + threadIdx.x) * 2;
-    if (column >= head_dim) return;
+    if (column >= sincos_size) return;

     int row = blockIdx.y * threads_y + threadIdx.y;
     if (row >= rows_per_batch) return;

@@ -131,13 +133,14 @@ __global__ void rope_cuda_kernel
     int past_len,
     const int32_t* __restrict__ past_lens,
     int threads_y,
-    const bool neox_style
+    const bool neox_style,
+    int sincos_size
 )
 {
     if (neox_style)
-        rope_cuda_arr_neox(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len, past_lens, threads_y);
+        rope_cuda_arr_neox(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len, past_lens, threads_y, sincos_size);
     else
-        rope_cuda_arr_gptj(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len, past_lens, threads_y);
+        rope_cuda_arr_gptj(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len, past_lens, threads_y, sincos_size);
 }

 __global__ void rope_cuda_qk_kernel

@@ -154,18 +157,19 @@ __global__ void rope_cuda_qk_kernel
     int past_len,
     const int32_t* __restrict__ past_lens,
     int threads_y,
-    const bool neox_style
+    const bool neox_style,
+    int sincos_size
 )
 {
     if (neox_style)
     {
-        rope_cuda_arr_neox(x_q, sin, cos, rows_per_batch_q, head_dim, num_heads_q, past_len, past_lens, threads_y);
-        rope_cuda_arr_neox(x_k, sin, cos, rows_per_batch_k, head_dim, num_heads_k, past_len, past_lens, threads_y);
+        rope_cuda_arr_neox(x_q, sin, cos, rows_per_batch_q, head_dim, num_heads_q, past_len, past_lens, threads_y, sincos_size);
+        rope_cuda_arr_neox(x_k, sin, cos, rows_per_batch_k, head_dim, num_heads_k, past_len, past_lens, threads_y, sincos_size);
     }
     else
     {
-        rope_cuda_arr_gptj(x_q, sin, cos, rows_per_batch_q, head_dim, num_heads_q, past_len, past_lens, threads_y);
-        rope_cuda_arr_gptj(x_k, sin, cos, rows_per_batch_k, head_dim, num_heads_k, past_len, past_lens, threads_y);
+        rope_cuda_arr_gptj(x_q, sin, cos, rows_per_batch_q, head_dim, num_heads_q, past_len, past_lens, threads_y, sincos_size);
+        rope_cuda_arr_gptj(x_k, sin, cos, rows_per_batch_k, head_dim, num_heads_k, past_len, past_lens, threads_y, sincos_size);
     }
 }

@@ -181,7 +185,8 @@ void rope_cuda
     const int num_heads,
     const int past_len,
     const int32_t* past_lens,
-    const bool neox_style
+    const bool neox_style,
+    int sincos_size
 )
 {
     // For large batch sizes we risk exceeding grid dimension of 65535, so shift to block dimension instead

@@ -207,7 +212,8 @@ void rope_cuda
         past_len,
         past_lens,
         threads_y,
-        neox_style
+        neox_style,
+        sincos_size
     );
 }

@@ -227,6 +233,7 @@ void rope_cuda_qk
     const int past_len,
     const int32_t* past_lens,
     const bool neox_style,
+    int sincos_size,
     Graph* graph,
     int label
 )

@@ -258,7 +265,8 @@ void rope_cuda_qk
         past_len,
         past_lens,
         threads_y,
-        neox_style
+        neox_style,
+        sincos_size
     );

     if (graph) graph->attach_label(stream, label, 0);
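The kernel changes above size the sin/cos matrix views to sincos_size and bail out for columns past that width, so features beyond the rotary slice are left untouched. As a plain PyTorch reference for the NeoX-style path (a sketch of the same idea, not the CUDA kernel itself, and assuming sin/cos are already expanded to rotary_dim columns in the usual duplicated-halves layout):

import torch

def partial_rope_neox(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    # x: (..., head_dim); sin/cos: (..., rotary_dim), broadcastable against x.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    half = rotary_dim // 2
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)   # NeoX-style "rotate half"
    x_rot = x_rot * cos + rotated * sin
    # Features past rotary_dim are concatenated back unrotated.
    return torch.cat((x_rot, x_pass), dim=-1)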

exllamav2/exllamav2_ext/cuda/rope.cuh
Lines changed: 3 additions & 1 deletion

@@ -19,7 +19,8 @@ void rope_cuda
     const int num_heads,
     const int past_len,
     const int32_t* past_lens,
-    const bool neox_style
+    const bool neox_style,
+    int sincos_size
 );

 void rope_cuda_qk

@@ -38,6 +39,7 @@ void rope_cuda_qk
     const int past_len,
     const int32_t* past_lens,
     const bool neox_style,
+    int sincos_size,
     Graph* graph = NULL,
     int label = 0
 );

exllamav2/exllamav2_ext/ext_qattn.cpp
Lines changed: 6 additions & 2 deletions

@@ -44,6 +44,7 @@ uintptr_t make_q_attn
     int max_seq_len,
     bool has_residual,
     int rope_style,
+    int sincos_size,
     torch::Tensor q_norm,
     torch::Tensor k_norm,
     torch::Tensor post_layernorm,

@@ -88,6 +89,7 @@ uintptr_t make_q_attn
         max_seq_len,
         has_residual,
         rope_style,
+        sincos_size,
         q_norm.is_meta() ? NULL : (half*) q_norm.data_ptr(),
         k_norm.is_meta() ? NULL : (half*) k_norm.data_ptr(),
         post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(),

@@ -377,7 +379,8 @@ void tp_attn_forward_paged_
             num_kv_heads,
             0, //past_len,
             (int32_t*) past_lens[i].data_ptr(),
-            rope_style == ROPE_STYLE_NEOX
+            rope_style == ROPE_STYLE_NEOX,
+            head_dim // TODO: partial_rotary_factor
         );
     }
 }

@@ -613,7 +616,8 @@ void tp_attn_forward_
             num_kv_heads,
             0, //past_len,
             (int32_t*) past_len_tp[i].data_ptr(),
-            rope_style == ROPE_STYLE_NEOX
+            rope_style == ROPE_STYLE_NEOX,
+            head_dim // TODO: partial_rotary_factor
         );
     }
 }

exllamav2/exllamav2_ext/ext_qattn.h
Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ uintptr_t make_q_attn
     int max_seq_len,
     bool has_residual,
     int rope_style,
+    int sincos_size,
     torch::Tensor q_norm,
     torch::Tensor k_norm,
     torch::Tensor post_layernorm,

exllamav2/exllamav2_ext/ext_rope.cpp
Lines changed: 6 additions & 3 deletions

@@ -32,12 +32,14 @@ void rope_
     TORCH_CHECK_DTYPE(x, kHalf);
     TORCH_CHECK_DTYPE(sin, kHalf);
     TORCH_CHECK_DTYPE(cos, kHalf);
-    TORCH_CHECK(head_dim == cos.size(-1), "cos table does not match head_dim");
-    TORCH_CHECK(head_dim == sin.size(-1), "sin table does not match head_dim");
+    // TORCH_CHECK(head_dim == cos.size(-1), "cos table does not match head_dim");
+    // TORCH_CHECK(head_dim == sin.size(-1), "sin table does not match head_dim");
+    TORCH_CHECK(cos.size(-1) == sin.size(-1), "sin table does not match cos table");
     TORCH_CHECK_DTYPE_OPT(offsets, kInt);

     int batch_size = x.size(0);
     int rows_per_batch = x.numel() / head_dim / batch_size;
+    int sincos_size = cos.size(-1);

     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
     cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

@@ -54,7 +56,8 @@ void rope_
         num_heads,
         past_len,
         offsets.device().is_meta() ? NULL : (int32_t*) offsets.data_ptr(),
-        neox_style
+        neox_style,
+        sincos_size
     );
 }

exllamav2/rope.py
Lines changed: 11 additions & 11 deletions

@@ -13,7 +13,7 @@ def get_rope_params_su(
     device: torch.Device,
     cfg: ExLlamaV2Config,
 ):
-    head_dim = cfg.head_dim
+    head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
     base = cfg.rotary_embedding_base
     if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
         base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))

@@ -36,7 +36,7 @@ def get_rope_params_llama3(
     device: torch.Device,
     cfg: ExLlamaV2Config,
 ):
-    head_dim = cfg.head_dim
+    head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
     base = cfg.rotary_embedding_base
     if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
         base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))

@@ -81,7 +81,8 @@ def get_rope_params_yarn(
     device: torch.Device,
     cfg: ExLlamaV2Config,
 ):
-    head_dim = cfg.head_dim
+    head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
+
     base = cfg.rotary_embedding_base
     if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
         base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))

@@ -91,9 +92,7 @@ def get_rope_params_yarn(
     # Only activate if longer than original ctx
     if cfg.max_seq_len > cfg.yarn_rope_original_max_position_embeddings:

-        partial_rotary_factor = 1.0 # Placeholder, assume no partial_rotary_factor in config.
-        dim = int(head_dim * partial_rotary_factor)
-
+        head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
         factor = cfg.yarn_rope_factor

         # Sets the attention factor as suggested in the paper

@@ -126,14 +125,14 @@ def linear_ramp_factor(min, max, dim):

         # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
         # to expand the possible context length. In other words, interpolation = apply scaling factor.
-        pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+        pos_freqs = base ** (torch.arange(0, head_dim, 2).float().to(device) / head_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (factor * pos_freqs)

-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, yarn_max_position_embeddings)
+        low, high = find_correction_range(beta_fast, beta_slow, head_dim, base, yarn_max_position_embeddings)

         # Get n-dimensional rotational scaling corrected for extrapolation
-        inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+        inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).float().to(device)
         inv_freq = (
             inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
             + inv_freq_extrapolation * inv_freq_extrapolation_factor

@@ -150,10 +149,11 @@ def get_rope_params_default(
     device: torch.Device,
     cfg: ExLlamaV2Config,
 ):
-    head_dim = cfg.head_dim
+    head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
+
     base = cfg.rotary_embedding_base
     if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
-        base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
+        base *= cfg.scale_alpha_value ** (head_dim / (head_dim - 2))

     inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
     return inv_freq, 1.0
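With the rope.py changes, every rope-params path builds inv_freq over the reduced rotary width, so the resulting sin/cos tables come out rotary_dim columns wide; that width is what attn.py passes to the extension as sincos_size. A small sketch of the default path with illustrative numbers:

import torch

head_dim = 128
partial_rotary_factor = 0.75   # illustrative; read from config, defaulting to 1.0
base = 10000.0

rotary_dim = int(head_dim * partial_rotary_factor)
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
print(inv_freq.shape)  # torch.Size([48]) -> sin/cos tables end up 96 columns wide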

0 commit comments
