@@ -580,6 +580,79 @@ struct decode_embd_batch {
     }
 };
 
+// Helper function for decoding an image whose embeddings have already been calculated
+int32_t mtmd_helper_decode_image_chunk(
+        mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_input_chunk * chunk,
+        float * encoded_embd,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        llama_pos * new_n_past) {
+    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
+        return -1;
+    }
+    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+    if (!image_tokens) {
+        LOG_ERR("failed to decode image chunk: image tokens are null\n");
+        return -1;
+    }
+
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
+
+    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+    const int nx = mtmd_image_tokens_get_nx(image_tokens);
+    const int ny = mtmd_image_tokens_get_ny(image_tokens);
+
+    if (mtmd_decode_use_mrope(ctx)) {
+        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+    } else {
+        batch_embd.set_position_normal(n_past, seq_id);
+    }
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, false);
+        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+    }
+
+    while (i_batch < n_img_batches) { // split into batches
+        int pos_offset = i_batch*n_batch;
+        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        int64_t t1 = ggml_time_ms();
+        int32_t ret = llama_decode(lctx, batch_embd_view);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(lctx, true); // restore causal attn
+            return ret;
+        }
+
+        if (ctx->print_timings) {
+            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+        }
+
+        i_batch++;
+    }
+
+    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    *new_n_past = n_past;
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, true);
+    }
+    return 0;
+}
+
 int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         struct llama_context * lctx,
         const mtmd_input_chunk * chunk,
@@ -591,8 +664,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
     int32_t ret;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
 
     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
         size_t n_tokens;
@@ -637,57 +708,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         if (ctx->print_timings) {
             LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
         }
-
-        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-        int32_t i_batch = 0;
-        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
         float * embd = mtmd_get_output_embd(ctx);
-        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
-
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
-        if (mtmd_decode_use_mrope(ctx)) {
-            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-        } else {
-            batch_embd.set_position_normal(n_past, seq_id);
-        }
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, false);
-            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
-        }
-
-        while (i_batch < n_img_batches) { // split into batches
-            int pos_offset = i_batch*n_batch;
-            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
-
-            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
-
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_embd_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_set_causal_attn(lctx, true); // restore causal attn
-                llama_batch_free(text_batch);
-                return ret;
-            }
-
-            if (ctx->print_timings) {
-                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-            }
-
-            i_batch++;
-        }
-
-        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
-        *new_n_past = n_past;
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, true);
+        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_batch_free(text_batch);
+            return ret;
         }
-
     } else {
         GGML_ABORT("chunk type not supported");
     }
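
The refactor leaves the encode step in mtmd_helper_eval_chunk_single and delegates decoding to the new helper, which callers can now also invoke directly with embeddings they cached themselves. Below is a minimal usage sketch, not part of the diff: it assumes `ctx`, `lctx`, and an image `chunk` are already set up, and that `cached_embd` (a hypothetical name) points to embeddings previously produced for this exact chunk, e.g. via mtmd_encode followed by mtmd_get_output_embd.

// Sketch only: cached_embd, the seq id, and the batch size are illustrative.
llama_pos n_past     = 0; // position where the image starts in the sequence
llama_pos new_n_past = 0; // set by the helper to the first position after the image

int32_t ret = mtmd_helper_decode_image_chunk(
    ctx,              // mtmd_context holding the clip model
    lctx,             // llama_context to decode into
    chunk,            // must be of type MTMD_INPUT_CHUNK_TYPE_IMAGE
    cached_embd,      // float *: embeddings computed earlier for this chunk
    n_past,
    /*seq_id =*/ 0,
    /*n_batch=*/ 512, // max tokens per llama_decode call; the helper splits as needed
    &new_n_past);
if (ret != 0) {
    // decode failed; the helper has already restored causal attention
}
n_past = new_n_past;  // continue with text tokens after the image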