Skip to content

Commit 4e45d31

Browse files
fix wrong index compute in f32x4 and f16x8 (#133)
wrong index compute in f32x4 and f16x8
1 parent f75d8f6 commit 4e45d31

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

embedding/embedding.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ __global__ void embedding_f32_kernel(const int *idx, float *weight, float *outpu
2424

2525
__global__ void embedding_f32x4_kernel(const int *idx, float *weight, float *output, int n, int emb_size)
2626
{
27-
int tx = threadIdx.x;
27+
int tx = threadIdx.x * 4;
2828
int bx = blockIdx.x;
29-
int tid = bx * blockDim.x + tx;
3029
int offset = idx[bx] * emb_size;
3130
output[bx * emb_size + tx] = weight[offset + tx];
3231
output[bx * emb_size + tx + 1] = weight[offset + tx + 1];
@@ -54,9 +53,8 @@ __global__ void embedding_f16_kernel(const int *idx, half *weight, half *output,
5453

5554
__global__ void embedding_f16x8_kernel(const int *idx, half *weight, half *output, int n, int emb_size)
5655
{
57-
int tx = threadIdx.x;
56+
int tx = threadIdx.x * 8;
5857
int bx = blockIdx.x;
59-
int tid = bx * blockDim.x + tx;
6058
int offset = idx[bx] * emb_size;
6159
output[bx * emb_size + tx] = weight[offset + tx];
6260
output[bx * emb_size + tx + 1] = weight[offset + tx + 1];

0 commit comments

Comments
 (0)