@@ -639,6 +639,7 @@ Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) {
639639 }
640640
641641 int32_t inputTokensLength = info[0 ].As <Napi::Number>().Int32Value ();
642+ int32_t maxVectorSize = (info.Length () > 1 && info[1 ].IsNumber ()) ? info[1 ].As <Napi::Number>().Int32Value () : 0 ;
642643
643644 if (inputTokensLength <= 0 ) {
644645 Napi::Error::New (info.Env (), " Invalid input tokens length" ).ThrowAsJavaScriptException ();
@@ -650,15 +651,16 @@ Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) {
650651 const auto * embeddings = pooling_type == LLAMA_POOLING_TYPE_NONE ? NULL : llama_get_embeddings_seq (ctx, 0 );
651652 if (embeddings == NULL ) {
652653 embeddings = llama_get_embeddings_ith (ctx, inputTokensLength - 1 );
654+ }
653655
654- if (embeddings == NULL ) {
655- Napi::Error::New (info.Env (), std::string (" Failed to get embeddings for token " ) + std::to_string (inputTokensLength - 1 )).ThrowAsJavaScriptException ();
656- return info.Env ().Undefined ();
657- }
656+ if (embeddings == NULL ) {
657+ Napi::Error::New (info.Env (), std::string (" Failed to get embeddings for token " ) + std::to_string (inputTokensLength - 1 )).ThrowAsJavaScriptException ();
658+ return info.Env ().Undefined ();
658659 }
659660
660- Napi::Float64Array result = Napi::Float64Array::New (info.Env (), n_embd);
661- for (size_t i = 0 ; i < n_embd; ++i) {
661+ size_t resultSize = maxVectorSize == 0 ? n_embd : std::min (n_embd, maxVectorSize);
662+ Napi::Float64Array result = Napi::Float64Array::New (info.Env (), resultSize);
663+ for (size_t i = 0 ; i < resultSize; i++) {
662664 result[i] = embeddings[i];
663665 }
664666
0 commit comments