@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
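From the caller's side nothing changes: the public accessors still return rows in the order of the user-provided batch, the reorder simply happens on first access instead of inside decode(). A minimal usage sketch in C++, assuming `ctx` and `n_outputs` come from the surrounding application code (both are placeholders, not part of this change):

    #include "llama.h"

    // read back outputs in user-batch order after a successful llama_decode();
    // the first accessor call applies any swaps that decode() deferred
    static void read_outputs(struct llama_context * ctx, int32_t n_outputs) {
        for (int32_t i = 0; i < n_outputs; ++i) {
            const float * logits = llama_get_logits_ith(ctx, i);
            (void) logits; // ... sample or inspect the row here ...
        }
    }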
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // make the outputs have the same order they had in the user-provided batch
     // note: this is mostly relevant for recurrent models atm
     if (!sorted_output) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint64_t n_embd  = model.hparams.n_embd;
-
         GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
         // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 continue;
             }
             std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
+
+            // remember the swaps and apply them lazily upon logits/embeddings access
+            output_swaps.push_back({ i, j_min });
         }
 
         std::fill(output_ids.begin(), output_ids.end(), -1);
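The declaration backing `output_swaps` is not part of the hunks shown here; below is a sketch of what the header side presumably adds to `llama_context`. Only the names `i0`, `i1`, and `output_swaps` are implied by the code in this diff; the type name and exact layout are assumptions:

    #include <cstdint>
    #include <vector>

    // assumed llama-context.h addition: pending output-row swaps recorded by decode()
    struct swap_info {    // name assumed; members i0/i1 match the accesses in output_reorder()
        uint32_t i0;
        uint32_t i1;
    };

    std::vector<swap_info> output_swaps; // drained (and cleared) by output_reorder()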
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
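The deferred replay in output_reorder() is equivalent to the eager swapping it replaces because the recorded pairs are applied in the same order the selection sort generated them. A standalone toy program (not llama.cpp code; one float stands in for each output row) that demonstrates the idea:

    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
        std::vector<int32_t> out_ids = { 2, 0, 3, 1 };         // batch order of the computed rows
        std::vector<float>   rows    = { 2.f, 0.f, 3.f, 1.f }; // one value per row, stands in for logits/embd

        std::vector<std::pair<uint32_t, uint32_t>> swaps;      // recorded instead of applied, as in decode()

        // selection sort on out_ids, recording the swaps
        for (uint32_t i = 0; i < out_ids.size() - 1; ++i) {
            uint32_t j_min = i;
            for (uint32_t j = i + 1; j < out_ids.size(); ++j) {
                if (out_ids[j] < out_ids[j_min]) {
                    j_min = j;
                }
            }
            if (j_min == i) {
                continue;
            }
            std::swap(out_ids[i], out_ids[j_min]);
            swaps.push_back({ i, j_min });
        }

        // lazy replay on first access, as in output_reorder()
        for (const auto & s : swaps) {
            std::swap(rows[s.first], rows[s.second]);
        }

        for (size_t i = 0; i < rows.size(); ++i) {
            printf("row %zu -> %.0f\n", i, rows[i]); // prints 0, 1, 2, 3: rows follow the user batch order
        }
        return 0;
    }

Either way the rows end up ordered by their original batch position; the only difference is when the memory traffic happens, and callers that never read logits or embeddings now skip it entirely.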