@@ -12,51 +12,46 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1212 const bool src0_is_quantized = (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16);
1313 const bool src1_is_quantized = (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16);
1414
15- // if (src0_is_quantized || src1_is_quantized) {
16- // printf("DEBUG: OUT_PROD with quantized tensors - src0_quantized=%d, src1_quantized=%d\n",
17- // src0_is_quantized, src1_is_quantized);
18- // fflush(stdout);
19- // }
20-
21- // GGML_ASSERT(src0->type == GGML_TYPE_F32);
22- // GGML_ASSERT(src1->type == GGML_TYPE_F32);
23-
2415 GGML_ASSERT (dst->type == GGML_TYPE_F32);
2516
17+ cudaStream_t stream = ctx.stream ();
18+ ggml_cuda_pool & pool = ctx.pool ();
19+
2620 // temp buffers
2721 float * src0_f32 = nullptr ;
2822 float * src1_f32 = nullptr ;
2923 bool allocated_src0 = false ;
3024 bool allocated_src1 = false ;
31- cudaStream_t stream = ctx.stream ();
25+ ggml_cuda_pool_alloc<float > src0_alloc (pool);
26+ ggml_cuda_pool_alloc<float > src1_alloc (pool);
3227
3328 if (src0_is_quantized) {
34- const size_t src0_size = ggml_nelements (src0) * sizeof (float );
35- CUDA_CHECK (cudaMallocAsync (&src0_f32, src0_size, stream));
29+ const size_t src0_size = ggml_nelements (src0);
30+ src0_alloc.alloc (src0_size);
31+ src0_f32 = src0_alloc.ptr ;
3632 allocated_src0 = true ;
3733
3834 // Dequantize
3935 auto dequantize_fn = ggml_get_to_fp32_cuda (src0->type );
4036 if (dequantize_fn) {
4137 dequantize_fn (src0->data , src0_f32, ggml_nelements (src0), stream);
4238 } else {
43- CUDA_CHECK (cudaFreeAsync (src0_f32, stream));
4439 GGML_ABORT (" Unsupported quant type for src0" );
4540 }
4641 } else {
4742 src0_f32 = (float *) src0->data ;
4843 }
4944
5045 if (src1_is_quantized) {
51- const size_t src1_size = ggml_nelements (src1) * sizeof (float );
52- CUDA_CHECK (cudaMallocAsync (&src1_f32, src1_size, stream));
46+ const size_t src1_size = ggml_nelements (src1);
47+ src1_alloc.alloc (src1_size);
48+ src1_f32 = src1_alloc.ptr ;
5349 allocated_src1 = true ;
5450
5551 auto dequantize_fn = ggml_get_to_fp32_cuda (src1->type );
5652 if (dequantize_fn) {
57- dequantize_fn (src1->data , src1_f32, ggml_nelements (src0 ), stream);
53+ dequantize_fn (src1->data , src1_f32, ggml_nelements (src1 ), stream);
5854 } else {
59- CUDA_CHECK (cudaFreeAsync (src1_f32, stream));
6055 GGML_ABORT (" Unsupported quant type for src1" );
6156 }
6257 } else {
@@ -74,9 +69,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
7469 GGML_ASSERT (ne2 == src1->ne [2 ]);
7570 GGML_ASSERT (ne3 == src1->ne [3 ]);
7671
77- // const float * src0_d = (const float *) src0->data;
78- // const float * src1_d = (const float *) src1->data;
79-
8072 // Use dequantized data
8173 const float * src0_d = src0_f32;
8274 const float * src1_d = src1_f32;
@@ -89,28 +81,21 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
8981
9082 CUBLAS_CHECK (cublasSetStream (handle, stream));
9183
92- // const int64_t lda = nb01 / sizeof(float);
9384 const int64_t lda = allocated_src0 ? ne00 : (nb01 / sizeof (float ));
9485 const int64_t ldc = nb1 / sizeof (float );
9586
9687 const bool src1_T = ggml_is_transposed (src1);
9788 const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
98- // const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
9989 const int64_t ldb = allocated_src1 ?
10090 (src1_T ? ne10 : ne11) :
10191 ((src1_T ? nb10 : nb11) / sizeof (float ));
10292
103- // GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
10493 // Only assert for non dequantized src1
10594 if (!allocated_src1) {
10695 GGML_ASSERT ((src1_T ? nb11 : nb10) == sizeof (float ));
10796 }
10897
10998 // data strides in dimensions 2/3
110- // const size_t s02 = nb02 / sizeof(float);
111- // const size_t s03 = nb03 / sizeof(float);
112- // const size_t s12 = nb12 / sizeof(float);
113- // const size_t s13 = nb13 / sizeof(float);
11499 const size_t s02 = allocated_src0 ? (ne00 * ne01) : nb02 / sizeof (float );
115100 const size_t s03 = allocated_src0 ? (ne00 * ne01 * ne02): nb03 / sizeof (float );
116101 const size_t s12 = allocated_src1 ? (ne10 * ne11) : nb12 / sizeof (float );
@@ -134,15 +119,4 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
134119 }
135120 }
136121
137- if (allocated_src0) {
138- CUDA_CHECK (cudaFreeAsync (src0_f32, stream));
139- // printf("DEBUG: Freed dequantized src0 buffer\n");
140- }
141- if (allocated_src1) {
142- CUDA_CHECK (cudaFreeAsync (src1_f32, stream));
143- // // printf("DEBUG: Freed dequantized src1 buffer\n");
144- }
145-
146- // printf("DEBUG: CUDA OUT_PROD completed successfully\n");
147- fflush (stdout);
148122}