@@ -1505,6 +1505,10 @@ struct llama_context {
15051505
15061506 // decode output (2-dimensional array: [n_tokens][n_vocab])
15071507 std::vector<float > logits;
1508+ #ifndef NDEBUG
1509+ // guard against access to unset logits
1510+ std::vector<bool > logits_valid;
1511+ #endif
15081512 bool logits_all = false ;
15091513
15101514 // input embedding (1-dimensional array: [n_embd])
@@ -6150,20 +6154,37 @@ static int llama_decode_internal(
61506154 {
61516155 auto & logits_out = lctx.logits ;
61526156
6157+ #ifndef NDEBUG
6158+ auto & logits_valid = lctx.logits_valid ;
6159+ logits_valid.clear ();
6160+ logits_valid.resize (n_tokens);
6161+
6162+ logits_out.clear ();
6163+ #endif
6164+
61536165 if (batch.logits ) {
61546166 logits_out.resize (n_vocab * n_tokens);
61556167 for (uint32_t i = 0 ; i < n_tokens; i++) {
61566168 if (batch.logits [i] == 0 ) {
61576169 continue ;
61586170 }
61596171 memcpy (logits_out.data () + (n_vocab*i), (float *) ggml_get_data (res) + (n_vocab*i), sizeof (float )*n_vocab);
6172+ #ifndef NDEBUG
6173+ logits_valid[i] = true ;
6174+ #endif
61606175 }
61616176 } else if (lctx.logits_all ) {
61626177 logits_out.resize (n_vocab * n_tokens);
61636178 memcpy (logits_out.data (), (float *) ggml_get_data (res), sizeof (float )*n_vocab*n_tokens);
6179+ #ifndef NDEBUG
6180+ std::fill (logits_valid.begin (), logits_valid.end (), true );
6181+ #endif
61646182 } else {
61656183 logits_out.resize (n_vocab);
61666184 memcpy (logits_out.data (), (float *) ggml_get_data (res) + (n_vocab*(n_tokens - 1 )), sizeof (float )*n_vocab);
6185+ #ifndef NDEBUG
6186+ logits_valid[n_tokens - 1 ] = true ;
6187+ #endif
61676188 }
61686189 }
61696190
@@ -10052,6 +10073,7 @@ float * llama_get_logits(struct llama_context * ctx) {
1005210073}
1005310074
1005410075float * llama_get_logits_ith (struct llama_context * ctx, int32_t i) {
10076+ assert (ctx->logits_valid .at (i));
1005510077 return ctx->logits .data () + i*ctx->model .hparams .n_vocab ;
1005610078}
1005710079
0 commit comments