Skip to content

Commit 663eea1

Browse files
committed
Fix 64-bit dtype for MSVC
1 parent e3b5549 commit 663eea1

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

exllamav2/exllamav2_ext/ext_rope.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,50 +58,50 @@ void rope_
5858
);
5959
}
6060

61-
long gen_mrope_pos_ids
61+
int64_t gen_mrope_pos_ids
6262
(
6363
torch::Tensor mrope_pos_ids,
6464
torch::Tensor ids,
6565
int merge_size,
66-
const std::vector<std::tuple<long, long>> &spans,
67-
const std::vector<std::tuple<long, long, long>> &grids
66+
const std::vector<std::tuple<int64_t, int64_t>> &spans,
67+
const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids
6868
)
6969
{
7070
int max_length = mrope_pos_ids.size(1);
7171
int in_length = ids.size(0);
7272

73-
long* in_ids = (long*) ids.data_ptr();
74-
long* pos_ids = (long*) mrope_pos_ids.data_ptr();
73+
int64_t* in_ids = (int64_t*) ids.data_ptr();
74+
int64_t* pos_ids = (int64_t*) mrope_pos_ids.data_ptr();
7575

76-
long* out_t = pos_ids;
77-
long* out_h = pos_ids + max_length;
78-
long* out_w = pos_ids + 2 * max_length;
76+
int64_t* out_t = pos_ids;
77+
int64_t* out_h = pos_ids + max_length;
78+
int64_t* out_w = pos_ids + 2 * max_length;
7979

80-
long base_t = 0;
81-
long next_base_t = 0;
80+
int64_t base_t = 0;
81+
int64_t next_base_t = 0;
8282

8383
for (int i = 0; i < max_length; ++i)
8484
{
8585
bool is_emb = false;
8686
if (i < in_length)
8787
{
88-
long id = in_ids[i];
88+
int64_t id = in_ids[i];
8989

9090
for (int j = 0; j < spans.size(); ++j)
9191
{
92-
long span_start = std::get<0>(spans[j]);
93-
long span_end = std::get<1>(spans[j]);
94-
long span = span_end - span_start;
92+
int64_t span_start = std::get<0>(spans[j]);
93+
int64_t span_end = std::get<1>(spans[j]);
94+
int64_t span = span_end - span_start;
9595
if (id >= span_start && id < span_end)
9696
{
9797
is_emb = true;
98-
long k = id - span_start;
99-
long grid_t = std::get<0>(grids[j]);
100-
long grid_h = std::get<1>(grids[j]) / (long)merge_size;
101-
long grid_w = std::get<2>(grids[j]) / (long)merge_size;
102-
long k_t = base_t + (k / grid_w / grid_h) % grid_t;
103-
long k_h = base_t + (k / grid_w) % grid_h;
104-
long k_w = base_t + k % grid_w;
98+
int64_t k = id - span_start;
99+
int64_t grid_t = std::get<0>(grids[j]);
100+
int64_t grid_h = std::get<1>(grids[j]) / (int64_t)merge_size;
101+
int64_t grid_w = std::get<2>(grids[j]) / (int64_t)merge_size;
102+
int64_t k_t = base_t + (k / grid_w / grid_h) % grid_t;
103+
int64_t k_h = base_t + (k / grid_w) % grid_h;
104+
int64_t k_w = base_t + k % grid_w;
105105
*out_t++ = k_t;
106106
*out_h++ = k_h;
107107
*out_w++ = k_w;

exllamav2/exllamav2_ext/ext_rope.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ void rope_
1111
bool neox_style
1212
);
1313

14-
long gen_mrope_pos_ids
14+
int64_t gen_mrope_pos_ids
1515
(
1616
torch::Tensor mrope_pos_ids,
1717
torch::Tensor ids,
1818
int merge_size,
19-
const std::vector<std::tuple<long, long>> &spans,
20-
const std::vector<std::tuple<long, long, long>> &grids
19+
const std::vector<std::tuple<int64_t, int64_t>> &spans,
20+
const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids
2121
);

0 commit comments

Comments
 (0)