Skip to content

Commit 06bcec2

Browse files
committed
Add test for matmul
1 parent 1b10421 commit 06bcec2

File tree

5 files changed

+92
-5
lines changed

5 files changed

+92
-5
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
*.prof
55
test/*.bin
66
test/test_add
7-
test/test_fma
7+
test/test_fma
8+
test/test_matmul

test/Makefile

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,29 @@ CFLAGS = -DTEST
44

55
TEST_ADD_CU = test_add.cu
66
TEST_FMA_CU = test_fma.cu
7-
KERNEL_CU = ../matmulf8_kernel.cu
7+
TEST_MATMUL_CU = test_matmul.cu
8+
KERNEL_CU = ../matmulf8_kernel.cu ../matmulf8.cu ../load_core.cu
89
OUTPUT_ADD = test_add
910
OUTPUT_FMA = test_fma
11+
OUTPUT_MATMUL = test_matmul
1012
BIN_FILE = test_add.bin test_fma.bin
1113

12-
OUTPUT_EXE = $(OUTPUT_ADD) $(OUTPUT_FMA)
14+
OUTPUT_EXE = $(OUTPUT_ADD) $(OUTPUT_FMA) $(OUTPUT_MATMUL)
1315

14-
all: $(OUTPUT_EXE) $(BIN_FILE)
16+
test: $(OUTPUT_EXE) $(BIN_FILE)
17+
./$(OUTPUT_ADD)
18+
./$(OUTPUT_FMA)
19+
./$(OUTPUT_MATMUL)
1520

1621
$(OUTPUT_ADD): $(TEST_ADD_CU) $(KERNEL_CU)
1722
$(NVCC) $(CFLAGS) $(TEST_ADD_CU) $(KERNEL_CU) -o $(OUTPUT_ADD)
1823

1924
$(OUTPUT_FMA): $(TEST_FMA_CU) $(KERNEL_CU)
2025
$(NVCC) $(CFLAGS) $(TEST_FMA_CU) $(KERNEL_CU) -o $(OUTPUT_FMA)
2126

27+
$(OUTPUT_MATMUL): $(TEST_MATMUL_CU) $(KERNEL_CU)
28+
$(NVCC) $(CFLAGS) $(TEST_MATMUL_CU) $(KERNEL_CU) -o $(OUTPUT_MATMUL)
29+
2230
$(BIN_FILE): gen_test_bin.py
2331
python gen_test_bin.py
2432

test/gen_test_bin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ def gen_fma_test_bin():
3434

3535
if __name__ == '__main__':
3636
gen_add_test_bin()
37-
gen_fma_test_bin()
37+
gen_fma_test_bin()
38+
print(float8_e5m2(128.0).tobytes().hex())

test/test_matmul.cu

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#include "../load_core.cuh"
2+
#include "../matmulf8.cuh"
3+
#include <cstdio>
4+
#include <random>
5+
#include "test_matmul.h"
6+
7+
void test_cross_line(u_int8_t *A, u_int8_t *B, u_int8_t *C, int n, int m, int p,
8+
int target_x, int target_y,
9+
int *acore, int *mcore){
10+
int success = 1;
11+
for(int i = 0; i < n * m; i++) {
12+
A[i] = 0;
13+
}
14+
for(int i = 0; i < m * p; i++) {
15+
B[i] = 0;
16+
}
17+
for (int i = 0; i < m; i++) {
18+
A[target_x * m + i] = 0x3c;
19+
}
20+
for (int i = 0; i < m; i++) {
21+
// B is transposed
22+
B[target_y * m + i] = 0x3c;
23+
}
24+
for (int i = 0; i < n * p; i++) {
25+
C[i] = 0xff;
26+
}
27+
28+
float t = matmul((int *)A, (int *)B, (int *)C, n, m, p, acore, mcore);
29+
30+
for (int i = 0; i < n; i++) {
31+
for (int j = 0; j < p; j++) {
32+
if (i == target_x && j == target_y){
33+
// 0x58 is float8_e5m2 of 128.0
34+
if(C[i * p + j] != 0x58){
35+
printf("Target(%d, %d) expected 0x58(128.0), got: %x\n", i, j, C[i * p + j]);
36+
success = 0;
37+
}
38+
}else{
39+
if(C[i * p + j] != 0){
40+
printf("(%d, %d) expected 0, got: %x", i, j, C[i * p + j]);
41+
success = 0;
42+
}
43+
}
44+
}
45+
}
46+
if(!success){
47+
printf("Test failed for (%d, %d)\n", target_x, target_y);
48+
}else{
49+
printf("Test passed for (%d, %d)\n", target_x, target_y);
50+
}
51+
}
52+
53+
int main(){
54+
cudaSetDevice(0);
55+
56+
int* acore = load_core("../apdcore.bin");
57+
int* mcore = load_core("../mltcore.bin");
58+
int n = 128, m = 128, p = 128;
59+
u_int8_t *A, *B, *C;
60+
61+
cudaMallocHost(&A, n * m);
62+
cudaMallocHost(&B, m * p);
63+
cudaMallocHost(&C, n * p);
64+
std::mt19937 gen(std::random_device{}());
65+
std::uniform_int_distribution<int> dis(0, 128);
66+
67+
for(int t = 0; t < 10; t++){
68+
int x = dis(gen), y = dis(gen);
69+
test_cross_line(A, B, C, n, m, p, x, y, acore, mcore);
70+
}
71+
72+
cudaFreeHost(A);
73+
cudaFreeHost(B);
74+
cudaFreeHost(C);
75+
}

test/test_matmul.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#pragma once
2+
void test_cross_line(int m, u_int8_t *A, u_int8_t *B, int n, int p, u_int8_t *C, int *acore, int *mcore);

0 commit comments

Comments
 (0)