@@ -477,6 +477,14 @@ static txe_compute_pipeline_state_s tsi_kernel_setup(enum ggml_tsavorite_kernel_
         kernel_pipeline->kernel_name = "TXE_RMS_NORM";
         flag = true;
         break;
+    case GGML_TSAVORITE_KERNEL_TYPE_SWIGLU:
+    {
+        kernel_pipeline->_mlir_fptr_2_input[DATA_TYPE_F32_INDEX] = &_mlir_ciface_txe_swiglu_host;
+        kernel_pipeline->_mlir_fptr_2_input[DATA_TYPE_F16_INDEX] = &_mlir_ciface_txe_swiglu_16_host;
+        kernel_pipeline->kernel_name = "TXE_SWI_GLU";
+        flag = true;
+        break;
+    }
     default:
         break;
     }
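Note: the two `_mlir_fptr_2_input` entries let the compute path pick the F32 or F16 MLIR entry point by table index instead of branching on type at every call site. A minimal dispatch sketch, assuming the index is derived from the first source tensor's type (illustrative only, not part of this patch):

    /* hypothetical call site: select the SWIGLU entry point by element type */
    int idx = (src0->type == GGML_TYPE_F16) ? DATA_TYPE_F16_INDEX : DATA_TYPE_F32_INDEX;
    kernel_pipeline->_mlir_fptr_2_input[idx](src0_desc, src1_desc, dst_desc);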
@@ -625,6 +633,7 @@ static struct ggml_backend_tsavorite_context *ggml_tsavorite_init(ggml_backend_d
         GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SIGMOID, true);
         GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SILU, true);
         GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM, true);
+        GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SWIGLU, true);
     }

     GGML_TSAVORITE_LOG_INFO("End %s\n", __func__);
@@ -704,7 +713,7 @@ static ggml_backend_tsavorite_buffer_s ggml_tsavorite_get_buffer(struct ggml_ten
     return tsi_nil;
 }
 #endif
-bool is_op_dtype_consistent_with_src(const struct ggml_tensor *op) {
+static bool is_op_dtype_consistent_with_src(const struct ggml_tensor *op) {
     uint32_t tensor_data_type = op->type;
     for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
         if (op->src[i] != NULL) {
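Note: `is_op_dtype_consistent_with_src` (now file-local via `static`) rejects any op whose inputs do not all share the op's own element type, so the F32/F16 kernels never see mixed-precision operands. The hunk only shows its opening lines; a sketch of the full check consistent with them:

    /* sketch: fail on the first src whose type differs from op->type */
    static bool is_op_dtype_consistent_with_src(const struct ggml_tensor *op) {
        uint32_t tensor_data_type = op->type;
        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
            if (op->src[i] != NULL && op->src[i]->type != tensor_data_type)
                return false;
        }
        return true;
    }

The loop bound is `GGML_MAX_DIMS` as written, though `GGML_MAX_SRC` would be the more natural constant for iterating `op->src`.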
@@ -720,16 +729,13 @@ static bool ggml_tsavorite_supports_op(const struct ggml_backend_tsavorite_devic
     GGML_TSAVORITE_LOG_INFO("Start %s\n", __func__);
     if (!ctx_dev)
         return false;
-    for (size_t i = 0, n = 3; i < n; ++i) {
-        if (op->src[i] != NULL && op->src[i]->type != GGML_TYPE_F32) {
-            return false;
-        }
-    }

-    if (op->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F16)
+    if (op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16)
         return false;
+
     if (!is_op_dtype_consistent_with_src(op))
         return false;
+
     switch (op->op) {
     case GGML_OP_NONE:
     case GGML_OP_ADD:
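Note: the `||` to `&&` change is the substantive fix in this hunk. No type can equal both F32 and F16 at once, so the old condition held for every tensor and the function rejected every op:

    /* old: (t != F32 || t != F16) -- true for ALL t, everything rejected      */
    /* new: (t != F32 && t != F16) -- true only when t is neither F32 nor F16  */

The deleted F32-only loop over `op->src` is subsumed by `is_op_dtype_consistent_with_src`, which additionally admits F16 inputs.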
@@ -739,9 +745,15 @@ static bool ggml_tsavorite_supports_op(const struct ggml_backend_tsavorite_devic
     case GGML_OP_SQRT:
     case GGML_OP_SQR:
     case GGML_OP_SIN:
-        break;
     case GGML_OP_RMS_NORM:
         break;
+    case GGML_OP_GLU:
+    {
+        const ggml_glu_op op_ext = ggml_get_glu_op(op);
+        if (op_ext != GGML_GLU_OP_SWIGLU)
+            return false;
+        break;
+    }
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(op)) {
         case GGML_UNARY_OP_NEG:
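Note: `GGML_OP_GLU` is a single ggml op whose concrete variant rides in the node's parameters; only SWIGLU is claimed here. In upstream ggml the accessor is roughly the following (paraphrased from memory, not part of this patch):

    /* upstream ggml keeps the GLU variant in op_params[0] */
    enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor *tensor) {
        return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
    }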
@@ -815,6 +827,36 @@ static MemRefDescriptor<Rank>* create_mlir_buf(int K) {
     return header;
 }

+static enum ggml_tsavorite_kernel_type tsi_glu_kernel_type(struct ggml_tensor *node) {
+    const ggml_glu_op op = ggml_get_glu_op(node);
+    enum ggml_tsavorite_kernel_type kernel_type;
+
+    switch (op) {
+    case GGML_GLU_OP_REGLU:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_REGLU;
+        break;
+    case GGML_GLU_OP_GEGLU:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_GEGLU;
+        break;
+    case GGML_GLU_OP_SWIGLU:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_SWIGLU;
+        break;
+    case GGML_GLU_OP_SWIGLU_OAI:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_SWIGLU_OAI;
+        break;
+    case GGML_GLU_OP_GEGLU_ERF:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_GEGLU_ERF;
+        break;
+    case GGML_GLU_OP_GEGLU_QUICK:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_GEGLU_QUICK;
+        break;
+    default:
+        kernel_type = GGML_TSAVORITE_KERNEL_TYPE_COUNT;
+        break;
+    }
+    return kernel_type;
+}
+
 // nodes are intermediates with multiple src tensors & an operation
 // Here we create multiple threads
 // Each thread runs the command buffer, picks a tensor, executes it, and gets the result back based on
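Note: the mapper covers every GLU variant even though tsi_kernel_setup only wires up a SWIGLU pipeline so far, and `GGML_TSAVORITE_KERNEL_TYPE_COUNT` doubles as a "no kernel" sentinel. Callers gate on it directly, as the graph-compute hunk below does:

    /* sketch of the sentinel contract; mirrors the GGML_OP_GLU case below */
    enum ggml_tsavorite_kernel_type kt = tsi_glu_kernel_type(node);
    if (kt == GGML_TSAVORITE_KERNEL_TYPE_COUNT) {
        /* unsupported GLU variant: reject or abort */
    }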
@@ -940,6 +982,16 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend,
         kernel_type = GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM;
         num_of_input_tensors = TSAVORITE_UNARY_INPUT_TENSORS;
         break;
+    case GGML_OP_GLU:
+        kernel_type = tsi_glu_kernel_type(node);
+        if (!src1)
+            src1 = src0;
+        if (kernel_type == GGML_TSAVORITE_KERNEL_TYPE_COUNT) {
+            GGML_TSAVORITE_LOG_ERROR("\nGGML_OP_GLU sub type is not correct\n");
+            return GGML_STATUS_ABORTED;
+        }
+        num_of_input_tensors = TSAVORITE_TWO_INPUT_TENSORS;
+        break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_NEG:
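Note: the `src1 = src0` fallback covers ggml's two GLU forms: the fused builder produces a node whose single input packs the gate and value halves (src1 is NULL), while the split builder supplies them as two tensors. Aliasing src1 to src0 lets the two-input kernel path serve both. A construction sketch, assuming the standard ggml GLU builders:

    /* fused form: one tensor, halved along ne[0] by the op itself */
    struct ggml_tensor *fused = ggml_swiglu(ctx, a);          /* src1 == NULL */
    /* split form: gate and value passed separately */
    struct ggml_tensor *split = ggml_swiglu_split(ctx, a, b); /* src1 == b    */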
@@ -1916,10 +1968,12 @@ static bool ggml_backend_tsavorite_device_supports_buft(ggml_backend_dev_t dev,
 // ggml_backend_sched_backend_id_from_cur -> ggml_backend_offload_op ->
 static bool ggml_backend_tsavorite_device_offload_op(ggml_backend_dev_t dev,
                                                      const struct ggml_tensor *op) {
-    if (op->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F16)
+    if (op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16)
         return false;
+
     if (!is_op_dtype_consistent_with_src(op))
         return false;
+
     switch (op->op) {
     case GGML_OP_NONE:
     case GGML_OP_ADD:
@@ -1931,6 +1985,13 @@ static bool ggml_backend_tsavorite_device_offload_op(ggml_backend_dev_t dev,
     case GGML_OP_SIN:
     case GGML_OP_RMS_NORM:
         break;
+    case GGML_OP_GLU:
+    {
+        const ggml_glu_op op_ext = ggml_get_glu_op(op);
+        if (op_ext != GGML_GLU_OP_SWIGLU)
+            return false;
+        break;
+    }
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(op)) {
         case GGML_UNARY_OP_NEG:
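Note: `ggml_backend_tsavorite_device_offload_op` now applies the same `&&` type fix and SWIGLU-only GLU gate as `ggml_tsavorite_supports_op`, keeping the scheduler's offload decision consistent with what the backend reports as supported. A hypothetical smoke test, assuming the public ggml API (names not from this patch):

    /* build a fused SWIGLU node; both predicates should now accept it */
    struct ggml_tensor *a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
    struct ggml_tensor *out = ggml_swiglu(ctx, a);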