Skip to content

Commit 8def574

Browse files
authored
Merge pull request #49 from tsisw/FIR-938
@FIR-938 - llama.cpp-GGML: Enable Support for 4D Tensor Data (Rank 4)
2 parents b3f9c99 + 9240a18 commit 8def574

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

ggml/include/ggml-tsavorite.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ extern void _mlir_ciface_txe_silu_host(void *a, void *res);
165165
extern void ggml_tsi_log_tensor_data(tensor_log log_data);
166166

167167
#define NUM_OF_TXES 1
168-
#define MEM_REF_DESCRIPTOR_RANK 1
168+
// GGML supports a maximum tensor rank of 4
169+
#define MEM_REF_DESCRIPTOR_RANK 4
169170

170171
//
171172
// backend API

ggml/src/ggml-tsavorite/ggml-tsavorite.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ static void _mlir_ciface_txe_add_test (void *src0, void *src1, void *res)
337337
srcP1 = (MemRefDescriptor<Rank> *)src1;
338338
nodeP = (MemRefDescriptor<Rank> *)res;
339339

340-
uint32_t count = srcP0->shape[Rank - 1];
340+
// TVU kernels operate on a single dimension for the add operation.
341+
uint32_t count = srcP0->shape[0];
342+
341343
float *s0 = (float*)srcP0->data;
342344
float *s1 = (float*)srcP1->data;
343345
float *n = (float*)nodeP->data;
@@ -360,7 +362,9 @@ static void _mlir_ciface_txe_mult_test (void *src0, void *src1, void *res)
360362
srcP1 = (MemRefDescriptor<Rank> *)src1;
361363
nodeP = (MemRefDescriptor<Rank> *)res;
362364

363-
uint32_t count = srcP0->shape[Rank - 1];
365+
// TVU kernels operate on a single dimension for the mul operation.
366+
uint32_t count = srcP0->shape[0];
367+
364368
float *s0 = (float*)srcP0->data;
365369
float *s1 = (float*)srcP1->data;
366370
float *n = (float*)nodeP->data;
@@ -985,10 +989,13 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend,
985989
float *src0_ptr = (float *)((char *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01);
986990
float *src1_ptr = (float *)((char *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11);
987991

992+
// The following code operates exclusively on dimension 0
993+
// (i.e., the first dimension) for all blob-related processing.
994+
988995
for (int64_t r = 0; r < nr0; ++r) {
989-
srcP0->shape[Rank - 1] = ne10;
990-
srcP1->shape[Rank - 1] = ne10;
991-
nodeP->shape[Rank - 1] = ne10;
996+
srcP0->shape[0] = ne10;
997+
srcP1->shape[0] = ne10;
998+
nodeP->shape[0] = ne10;
992999
srcP1->data = srcP1->base = (void *)(src1_ptr);
9931000
srcP0->data = srcP0->base = (void *)(src0_ptr + r * ne10);
9941001
nodeP->data = nodeP->base = (void *)(dst_ptr + r * ne10);
@@ -1058,10 +1065,13 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend,
10581065

10591066
srcP0->data = srcP0->base = (void *)((float *)src0->data);
10601067
nodeP->data = nodeP->base = (void *)((float *)node->data);
1061-
srcP0->shape[Rank - 1] = num_elem_src0;
1062-
nodeP->shape[Rank - 1] = num_elem_src0;
1063-
srcP0->strides[Rank - 1] = 0;
1064-
nodeP->strides[Rank - 1] = 0;
1068+
1069+
// The following code operates exclusively on dimension 0
1070+
// (i.e., the first dimension) for all blob-related processing.
1071+
srcP0->shape[0] = num_elem_src0;
1072+
nodeP->shape[0] = num_elem_src0;
1073+
srcP0->strides[0] = 0;
1074+
nodeP->strides[0] = 0;
10651075
// kernel call
10661076
ctx->kernels[kernel_type].pipeline->_mlir_fptr_1_input(srcP0, nodeP);
10671077
++device->stats.op_run_count[kernel_type].num_of_kernel_call;

0 commit comments

Comments
 (0)