Skip to content

Commit c806b46

Browse files
authored
Merge pull request #59 from tsisw/FIR-998
@FIR-998 - Create GLU/SWIGLU Support for posix and fpga
2 parents ffcc5ca + fbaebf9 commit c806b46

File tree

4 files changed

+84
-11
lines changed

4 files changed

+84
-11
lines changed

ggml-tsi-kernel

ggml/include/ggml-tsavorite.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ enum ggml_tsavorite_kernel_type {
129129
GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM,
130130
GGML_TSAVORITE_KERNEL_TYPE_SIGMOID,
131131
GGML_TSAVORITE_KERNEL_TYPE_SILU,
132+
// Below: GLU-family kernels (REGLU, GEGLU, SWIGLU, ...)
133+
GGML_TSAVORITE_KERNEL_TYPE_REGLU,
134+
GGML_TSAVORITE_KERNEL_TYPE_GEGLU,
135+
136+
// Currently only the SWIGLU kernel below is implemented
137+
GGML_TSAVORITE_KERNEL_TYPE_SWIGLU,
138+
139+
GGML_TSAVORITE_KERNEL_TYPE_SWIGLU_OAI,
140+
GGML_TSAVORITE_KERNEL_TYPE_GEGLU_ERF,
141+
GGML_TSAVORITE_KERNEL_TYPE_GEGLU_QUICK,
132142

133143
GGML_TSAVORITE_KERNEL_TYPE_COUNT
134144
};
@@ -174,6 +184,7 @@ extern void _mlir_ciface_txe_abs_host(void *a, void *res);
174184
extern void _mlir_ciface_txe_sin_host(void *a, void *res);
175185
extern void _mlir_ciface_txe_sigmoid_host(void *a, void *res);
176186
extern void _mlir_ciface_txe_silu_host(void *a, void *res);
187+
extern void _mlir_ciface_txe_swiglu_host(void *a, void *b, void *res);
177188
extern void _mlir_ciface_txe_rms_norm_host(void *a, void *res, void *buf);
178189

179190
/*
@@ -190,6 +201,7 @@ extern void _mlir_ciface_txe_abs_16_host(void *a, void *res);
190201
extern void _mlir_ciface_txe_sin_16_host(void *a, void *res);
191202
extern void _mlir_ciface_txe_sigmoid_16_host(void *a, void *res);
192203
extern void _mlir_ciface_txe_silu_16_host(void *a, void *res);
204+
extern void _mlir_ciface_txe_swiglu_16_host(void *a, void *b, void *res);
193205
extern void _mlir_ciface_txe_rms_norm_16_host(void *a, void *res, void *buf);
194206

195207
extern void ggml_tsi_log_tensor_data(tensor_log log_data);

ggml/src/ggml-tsavorite/ggml-tsavorite.cpp

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,14 @@ static txe_compute_pipeline_state_s tsi_kernel_setup(enum ggml_tsavorite_kernel_
477477
kernel_pipeline->kernel_name = "TXE_RMS_NORM";
478478
flag = true;
479479
break;
480+
case GGML_TSAVORITE_KERNEL_TYPE_SWIGLU:
481+
{
482+
kernel_pipeline->_mlir_fptr_2_input[DATA_TYPE_F32_INDEX] = &_mlir_ciface_txe_swiglu_host;
483+
kernel_pipeline->_mlir_fptr_2_input[DATA_TYPE_F16_INDEX] = &_mlir_ciface_txe_swiglu_16_host;
484+
kernel_pipeline->kernel_name = "TXE_SWI_GLU";
485+
flag = true;
486+
break;
487+
}
480488
default:
481489
break;
482490
}
@@ -625,6 +633,7 @@ static struct ggml_backend_tsavorite_context *ggml_tsavorite_init(ggml_backend_d
625633
GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SIGMOID, true);
626634
GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SILU, true);
627635
GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM, true);
636+
GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SWIGLU, true);
628637
}
629638

630639
GGML_TSAVORITE_LOG_INFO("End %s\n", __func__);
@@ -704,7 +713,7 @@ static ggml_backend_tsavorite_buffer_s ggml_tsavorite_get_buffer(struct ggml_ten
704713
return tsi_nil;
705714
}
706715
#endif
707-
bool is_op_dtype_consistent_with_src(const struct ggml_tensor *op) {
716+
static bool is_op_dtype_consistent_with_src(const struct ggml_tensor *op) {
708717
uint32_t tensor_data_type = op->type;
709718
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
710719
if (op->src[i] != NULL) {
@@ -720,16 +729,13 @@ static bool ggml_tsavorite_supports_op(const struct ggml_backend_tsavorite_devic
720729
GGML_TSAVORITE_LOG_INFO("Start %s\n", __func__);
721730
if (!ctx_dev)
722731
return false;
723-
for (size_t i = 0, n = 3; i < n; ++i) {
724-
if (op->src[i] != NULL && op->src[i]->type != GGML_TYPE_F32) {
725-
return false;
726-
}
727-
}
728732

729-
if (op->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F16)
733+
if (op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16)
730734
return false;
735+
731736
if (!is_op_dtype_consistent_with_src(op))
732737
return false;
738+
733739
switch (op->op) {
734740
case GGML_OP_NONE:
735741
case GGML_OP_ADD:
@@ -739,9 +745,15 @@ static bool ggml_tsavorite_supports_op(const struct ggml_backend_tsavorite_devic
739745
case GGML_OP_SQRT:
740746
case GGML_OP_SQR:
741747
case GGML_OP_SIN:
742-
break;
743748
case GGML_OP_RMS_NORM:
744749
break;
750+
case GGML_OP_GLU:
751+
{
752+
const ggml_glu_op op_ext = ggml_get_glu_op(op);
753+
if (op_ext != GGML_GLU_OP_SWIGLU)
754+
return false;
755+
break;
756+
}
745757
case GGML_OP_UNARY:
746758
switch (ggml_get_unary_op(op)) {
747759
case GGML_UNARY_OP_NEG:
@@ -815,6 +827,36 @@ static MemRefDescriptor<Rank>* create_mlir_buf(int K) {
815827
return header;
816828
}
817829

830+
static enum ggml_tsavorite_kernel_type tsi_glu_kernel_type(struct ggml_tensor *node) {
831+
const ggml_glu_op op = ggml_get_glu_op(node);
832+
enum ggml_tsavorite_kernel_type kernel_type;
833+
834+
switch (op) {
835+
case GGML_GLU_OP_REGLU:
836+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_REGLU;
837+
break;
838+
case GGML_GLU_OP_GEGLU:
839+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_GEGLU;
840+
break;
841+
case GGML_GLU_OP_SWIGLU:
842+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_SWIGLU;
843+
break;
844+
case GGML_GLU_OP_SWIGLU_OAI:
845+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_SWIGLU_OAI;
846+
break;
847+
case GGML_GLU_OP_GEGLU_ERF:
848+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_GEGLU_ERF;
849+
break;
850+
case GGML_GLU_OP_GEGLU_QUICK:
851+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_GEGLU_QUICK;
852+
break;
853+
default:
854+
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_COUNT;
855+
break;
856+
}
857+
return kernel_type;
858+
}
859+
818860
// nodes are intermediate which has multiple src tensors & operation
819861
// Here we create multiple thread
820862
// Each Thread run the command buffer & pick Tensor and execute and get the result back base on
@@ -940,6 +982,16 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend,
940982
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM;
941983
num_of_input_tensors = TSAVORITE_UNARY_INPUT_TENSORS;
942984
break;
985+
case GGML_OP_GLU:
986+
kernel_type = tsi_glu_kernel_type(node);
987+
if (!src1)
988+
src1 = src0;
989+
if (kernel_type == GGML_TSAVORITE_KERNEL_TYPE_COUNT) {
990+
GGML_TSAVORITE_LOG_ERROR("\n GGML_OP_GLU sub type is not correct \n");
991+
return GGML_STATUS_ABORTED;
992+
}
993+
num_of_input_tensors = TSAVORITE_TWO_INPUT_TENSORS;
994+
break;
943995
case GGML_OP_UNARY:
944996
switch (ggml_get_unary_op(node)) {
945997
case GGML_UNARY_OP_NEG:
@@ -1916,10 +1968,12 @@ static bool ggml_backend_tsavorite_device_supports_buft(ggml_backend_dev_t dev,
19161968
// ggml_backend_sched_backend_id_from_cur -> ggml_backend_offload_op ->
19171969
static bool ggml_backend_tsavorite_device_offload_op(ggml_backend_dev_t dev,
19181970
const struct ggml_tensor *op) {
1919-
if (op->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F16)
1971+
if (op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16)
19201972
return false;
1973+
19211974
if (!is_op_dtype_consistent_with_src(op))
19221975
return false;
1976+
19231977
switch (op->op) {
19241978
case GGML_OP_NONE:
19251979
case GGML_OP_ADD:
@@ -1931,6 +1985,13 @@ static bool ggml_backend_tsavorite_device_offload_op(ggml_backend_dev_t dev,
19311985
case GGML_OP_SIN:
19321986
case GGML_OP_RMS_NORM:
19331987
break;
1988+
case GGML_OP_GLU:
1989+
{
1990+
const ggml_glu_op op_ext = ggml_get_glu_op(op);
1991+
if (op_ext != GGML_GLU_OP_SWIGLU)
1992+
return false;
1993+
break;
1994+
}
19341995
case GGML_OP_UNARY:
19351996
switch (ggml_get_unary_op(op)) {
19361997
case GGML_UNARY_OP_NEG:

tsi-pkg-build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ cat > ./${TSI_GGML_BUNDLE_INSTALL_DIR}/ggml.sh << EOL
9090
# Set up library paths for GCC 13.3.0 compatibility
9191
export LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:\$(pwd)
9292
93-
tsi_kernels=("add" "sub" "mult" "div" "abs" "inv" "neg" "sin" "sqrt" "sqr" "sigmoid" "silu" "rms_norm" "add_16" "sub_16" "mult_16" "div_16" "abs_16" "inv_16" "neg_16" "sin_16" "sqrt_16" "sqr_16" "sigmoid_16" "silu_16" "rms_norm_16")
93+
tsi_kernels=("add" "sub" "mult" "div" "abs" "inv" "neg" "sin" "sqrt" "sqr" "sigmoid" "silu" "rms_norm" "swiglu" "add_16" "sub_16" "mult_16" "div_16" "abs_16" "inv_16" "neg_16" "sin_16" "sqrt_16" "sqr_16" "sigmoid_16" "silu_16" "rms_norm_16" "swiglu_16")
9494
9595
for kernel in "\${tsi_kernels[@]}"; do
9696
mkdir -p ${TSI_BLOB_INSTALL_DIR}/txe_\$kernel

0 commit comments

Comments
 (0)