@@ -98,6 +98,11 @@ private GenAiOperationNameIncubatingValues() {}
9898 private static final JsonNodeParser JSON_PARSER = JsonNode .parser ();
9999 private static final DocumentUnmarshaller DOCUMENT_UNMARSHALLER = new DocumentUnmarshaller ();
100100
101+ // Rough heuristic (~6 characters per token) used to approximate input/output token
102+ // counts for Cohere and Mistral AI models, which don't report token usage in the
103+ // response body. See:
104+ // https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
104+ private static final Double CHARS_PER_TOKEN = 6.0 ;
105+
101106 static boolean isBedrockRuntimeRequest (SdkRequest request ) {
102107 if (request instanceof ConverseRequest ) {
103108 return true ;
@@ -244,8 +249,12 @@ private static Long getMaxTokensInvokeModel(
244249 return null ;
245250 }
246251 count = config .asMap ().get ("max_new_tokens" );
247- } else if (modelId .startsWith ("anthropic.claude" )) {
252+ } else if (modelId .startsWith ("anthropic.claude" )
253+ || modelId .startsWith ("cohere.command" )
254+ || modelId .startsWith ("mistral.mistral" )) {
248255 count = body .asMap ().get ("max_tokens" );
256+ } else if (modelId .startsWith ("meta.llama" )) {
257+ count = body .asMap ().get ("max_gen_len" );
249258 }
250259 if (count != null && count .isNumber ()) {
251260 return count .asNumber ().longValue ();
@@ -300,7 +309,10 @@ private static Double getTemperatureInvokeModel(
300309 return null ;
301310 }
302311 temperature = config .asMap ().get ("temperature" );
303- } else if (modelId .startsWith ("anthropic.claude" )) {
312+ } else if (modelId .startsWith ("anthropic.claude" )
313+ || modelId .startsWith ("meta.llama" )
314+ || modelId .startsWith ("cohere.command" )
315+ || modelId .startsWith ("mistral.mistral" )) {
304316 temperature = body .asMap ().get ("temperature" );
305317 }
306318 if (temperature != null && temperature .isNumber ()) {
@@ -354,8 +366,12 @@ private static Double getToppInvokeModel(
354366 return null ;
355367 }
356368 topP = config .asMap ().get ("topP" );
357- } else if (modelId .startsWith ("anthropic.claude" )) {
369+ } else if (modelId .startsWith ("anthropic.claude" )
370+ || modelId .startsWith ("meta.llama" )
371+ || modelId .startsWith ("mistral.mistral" )) {
358372 topP = body .asMap ().get ("top_p" );
373+ } else if (modelId .startsWith ("cohere.command" )) {
374+ topP = body .asMap ().get ("p" );
359375 }
360376 if (topP != null && topP .isNumber ()) {
361377 return topP .asNumber ().doubleValue ();
@@ -409,9 +425,12 @@ private static List<String> getStopSequences(
409425 return null ;
410426 }
411427 stopSequences = config .asMap ().get ("stopSequences" );
412- } else if (modelId .startsWith ("anthropic.claude" )) {
428+ } else if (modelId .startsWith ("anthropic.claude" ) || modelId . startsWith ( "cohere.command" ) ) {
413429 stopSequences = body .asMap ().get ("stop_sequences" );
430+ } else if (modelId .startsWith ("mistral.mistral" )) {
431+ stopSequences = body .asMap ().get ("stop" );
414432 }
433+ // Meta Llama requests do not accept a stop-sequences parameter, so none is recorded.
415434 if (stopSequences != null && stopSequences .isList ()) {
416435 return stopSequences .asList ().stream ()
417436 .filter (Document ::isString )
@@ -474,8 +493,38 @@ private static List<String> getStopReasons(
474493 Document stopReason = null ;
475494 if (modelId .startsWith ("amazon.nova" )) {
476495 stopReason = body .asMap ().get ("stopReason" );
477- } else if (modelId .startsWith ("anthropic.claude" )) {
496+ } else if (modelId .startsWith ("anthropic.claude" ) || modelId . startsWith ( "meta.llama" ) ) {
478497 stopReason = body .asMap ().get ("stop_reason" );
498+ } else if (modelId .startsWith ("cohere.command-r" )) {
499+ stopReason = body .asMap ().get ("finish_reason" );
500+ } else if (modelId .startsWith ("cohere.command" )) {
501+ List <String > stopReasons = new ArrayList <>();
502+ Document results = body .asMap ().get ("generations" );
503+ if (results == null || !results .isList ()) {
504+ return null ;
505+ }
506+ for (Document result : results .asList ()) {
507+ stopReason = result .asMap ().get ("finish_reason" );
508+ if (stopReason == null || !stopReason .isString ()) {
509+ continue ;
510+ }
511+ stopReasons .add (stopReason .asString ());
512+ }
513+ return stopReasons ;
514+ } else if (modelId .startsWith ("mistral.mistral" )) {
515+ List <String > stopReasons = new ArrayList <>();
516+ Document results = body .asMap ().get ("outputs" );
517+ if (results == null || !results .isList ()) {
518+ return null ;
519+ }
520+ for (Document result : results .asList ()) {
521+ stopReason = result .asMap ().get ("stop_reason" );
522+ if (stopReason == null || !stopReason .isString ()) {
523+ continue ;
524+ }
525+ stopReasons .add (stopReason .asString ());
526+ }
527+ return stopReasons ;
479528 }
480529 if (stopReason != null && stopReason .isString ()) {
481530 return Collections .singletonList (stopReason .asString ());
@@ -534,6 +583,30 @@ private static Long getUsageInputTokens(
534583 return null ;
535584 }
536585 count = usage .asMap ().get ("input_tokens" );
586+ } else if (modelId .startsWith ("meta.llama" )) {
587+ count = body .asMap ().get ("prompt_token_count" );
588+ } else if (modelId .startsWith ("cohere.command-r" )) {
589+ // approximate input tokens based on prompt length
590+ Document requestBody = executionAttributes .getAttribute (INVOKE_MODEL_REQUEST_BODY );
591+ if (requestBody == null || !requestBody .isMap ()) {
592+ return null ;
593+ }
594+ String prompt = requestBody .asMap ().get ("message" ).asString ();
595+ if (prompt == null ) {
596+ return null ;
597+ }
598+ count = Document .fromNumber (Math .ceil (prompt .length () / CHARS_PER_TOKEN ));
599+ } else if (modelId .startsWith ("cohere.command" ) || modelId .startsWith ("mistral.mistral" )) {
600+ // approximate input tokens based on prompt length
601+ Document requestBody = executionAttributes .getAttribute (INVOKE_MODEL_REQUEST_BODY );
602+ if (requestBody == null || !requestBody .isMap ()) {
603+ return null ;
604+ }
605+ String prompt = requestBody .asMap ().get ("prompt" ).asString ();
606+ if (prompt == null ) {
607+ return null ;
608+ }
609+ count = Document .fromNumber (Math .ceil (prompt .length () / CHARS_PER_TOKEN ));
537610 }
538611 if (count != null && count .isNumber ()) {
539612 return count .asNumber ().longValue ();
@@ -604,6 +677,42 @@ private static Long getUsageOutputTokens(
604677 return null ;
605678 }
606679 count = usage .asMap ().get ("output_tokens" );
680+ } else if (modelId .startsWith ("meta.llama" )) {
681+ count = body .asMap ().get ("generation_token_count" );
682+ } else if (modelId .startsWith ("cohere.command-r" )) {
683+ Document text = body .asMap ().get ("text" );
684+ if (text == null || !text .isString ()) {
685+ return null ;
686+ }
687+ count = Document .fromNumber (Math .ceil (text .asString ().length () / CHARS_PER_TOKEN ));
688+ } else if (modelId .startsWith ("cohere.command" )) {
689+ Document generations = body .asMap ().get ("generations" );
690+ if (generations == null || !generations .isList ()) {
691+ return null ;
692+ }
693+ long outputLength = 0 ;
694+ for (Document generation : generations .asList ()) {
695+ Document text = generation .asMap ().get ("text" );
696+ if (text == null || !text .isString ()) {
697+ continue ;
698+ }
699+ outputLength += text .asString ().length ();
700+ }
701+ count = Document .fromNumber (Math .ceil (outputLength / CHARS_PER_TOKEN ));
702+ } else if (modelId .startsWith ("mistral.mistral" )) {
703+ Document outputs = body .asMap ().get ("outputs" );
704+ if (outputs == null || !outputs .isList ()) {
705+ return null ;
706+ }
707+ long outputLength = 0 ;
708+ for (Document output : outputs .asList ()) {
709+ Document text = output .asMap ().get ("text" );
710+ if (text == null || !text .isString ()) {
711+ continue ;
712+ }
713+ outputLength += text .asString ().length ();
714+ }
715+ count = Document .fromNumber (Math .ceil (outputLength / CHARS_PER_TOKEN ));
607716 }
608717 if (count != null && count .isNumber ()) {
609718 return count .asNumber ().longValue ();
0 commit comments