Skip to content

Commit 7561e77

Browse files
authored
Add gen ai support for additional models (open-telemetry#13682)
1 parent abe9b81 commit 7561e77

6 files changed

+487
-5
lines changed

instrumentation/aws-sdk/aws-sdk-2.2/library/src/main/java/io/opentelemetry/instrumentation/awssdk/v2_2/internal/BedrockRuntimeImpl.java

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ private GenAiOperationNameIncubatingValues() {}
9898
private static final JsonNodeParser JSON_PARSER = JsonNode.parser();
9999
private static final DocumentUnmarshaller DOCUMENT_UNMARSHALLER = new DocumentUnmarshaller();
100100

101+
// Used to approximate input/output token counts for Cohere and Mistral AI models,
102+
// which don't provide these values in the response body.
103+
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
104+
private static final Double CHARS_PER_TOKEN = 6.0;
105+
101106
static boolean isBedrockRuntimeRequest(SdkRequest request) {
102107
if (request instanceof ConverseRequest) {
103108
return true;
@@ -244,8 +249,12 @@ private static Long getMaxTokensInvokeModel(
244249
return null;
245250
}
246251
count = config.asMap().get("max_new_tokens");
247-
} else if (modelId.startsWith("anthropic.claude")) {
252+
} else if (modelId.startsWith("anthropic.claude")
253+
|| modelId.startsWith("cohere.command")
254+
|| modelId.startsWith("mistral.mistral")) {
248255
count = body.asMap().get("max_tokens");
256+
} else if (modelId.startsWith("meta.llama")) {
257+
count = body.asMap().get("max_gen_len");
249258
}
250259
if (count != null && count.isNumber()) {
251260
return count.asNumber().longValue();
@@ -300,7 +309,10 @@ private static Double getTemperatureInvokeModel(
300309
return null;
301310
}
302311
temperature = config.asMap().get("temperature");
303-
} else if (modelId.startsWith("anthropic.claude")) {
312+
} else if (modelId.startsWith("anthropic.claude")
313+
|| modelId.startsWith("meta.llama")
314+
|| modelId.startsWith("cohere.command")
315+
|| modelId.startsWith("mistral.mistral")) {
304316
temperature = body.asMap().get("temperature");
305317
}
306318
if (temperature != null && temperature.isNumber()) {
@@ -354,8 +366,12 @@ private static Double getToppInvokeModel(
354366
return null;
355367
}
356368
topP = config.asMap().get("topP");
357-
} else if (modelId.startsWith("anthropic.claude")) {
369+
} else if (modelId.startsWith("anthropic.claude")
370+
|| modelId.startsWith("meta.llama")
371+
|| modelId.startsWith("mistral.mistral")) {
358372
topP = body.asMap().get("top_p");
373+
} else if (modelId.startsWith("cohere.command")) {
374+
topP = body.asMap().get("p");
359375
}
360376
if (topP != null && topP.isNumber()) {
361377
return topP.asNumber().doubleValue();
@@ -409,9 +425,12 @@ private static List<String> getStopSequences(
409425
return null;
410426
}
411427
stopSequences = config.asMap().get("stopSequences");
412-
} else if (modelId.startsWith("anthropic.claude")) {
428+
} else if (modelId.startsWith("anthropic.claude") || modelId.startsWith("cohere.command")) {
413429
stopSequences = body.asMap().get("stop_sequences");
430+
} else if (modelId.startsWith("mistral.mistral")) {
431+
stopSequences = body.asMap().get("stop");
414432
}
433+
// Meta Llama requests do not support stop sequences
415434
if (stopSequences != null && stopSequences.isList()) {
416435
return stopSequences.asList().stream()
417436
.filter(Document::isString)
@@ -474,8 +493,38 @@ private static List<String> getStopReasons(
474493
Document stopReason = null;
475494
if (modelId.startsWith("amazon.nova")) {
476495
stopReason = body.asMap().get("stopReason");
477-
} else if (modelId.startsWith("anthropic.claude")) {
496+
} else if (modelId.startsWith("anthropic.claude") || modelId.startsWith("meta.llama")) {
478497
stopReason = body.asMap().get("stop_reason");
498+
} else if (modelId.startsWith("cohere.command-r")) {
499+
stopReason = body.asMap().get("finish_reason");
500+
} else if (modelId.startsWith("cohere.command")) {
501+
List<String> stopReasons = new ArrayList<>();
502+
Document results = body.asMap().get("generations");
503+
if (results == null || !results.isList()) {
504+
return null;
505+
}
506+
for (Document result : results.asList()) {
507+
stopReason = result.asMap().get("finish_reason");
508+
if (stopReason == null || !stopReason.isString()) {
509+
continue;
510+
}
511+
stopReasons.add(stopReason.asString());
512+
}
513+
return stopReasons;
514+
} else if (modelId.startsWith("mistral.mistral")) {
515+
List<String> stopReasons = new ArrayList<>();
516+
Document results = body.asMap().get("outputs");
517+
if (results == null || !results.isList()) {
518+
return null;
519+
}
520+
for (Document result : results.asList()) {
521+
stopReason = result.asMap().get("stop_reason");
522+
if (stopReason == null || !stopReason.isString()) {
523+
continue;
524+
}
525+
stopReasons.add(stopReason.asString());
526+
}
527+
return stopReasons;
479528
}
480529
if (stopReason != null && stopReason.isString()) {
481530
return Collections.singletonList(stopReason.asString());
@@ -534,6 +583,30 @@ private static Long getUsageInputTokens(
534583
return null;
535584
}
536585
count = usage.asMap().get("input_tokens");
586+
} else if (modelId.startsWith("meta.llama")) {
587+
count = body.asMap().get("prompt_token_count");
588+
} else if (modelId.startsWith("cohere.command-r")) {
589+
// approximate input tokens based on prompt length
590+
Document requestBody = executionAttributes.getAttribute(INVOKE_MODEL_REQUEST_BODY);
591+
if (requestBody == null || !requestBody.isMap()) {
592+
return null;
593+
}
594+
String prompt = requestBody.asMap().get("message").asString();
595+
if (prompt == null) {
596+
return null;
597+
}
598+
count = Document.fromNumber(Math.ceil(prompt.length() / CHARS_PER_TOKEN));
599+
} else if (modelId.startsWith("cohere.command") || modelId.startsWith("mistral.mistral")) {
600+
// approximate input tokens based on prompt length
601+
Document requestBody = executionAttributes.getAttribute(INVOKE_MODEL_REQUEST_BODY);
602+
if (requestBody == null || !requestBody.isMap()) {
603+
return null;
604+
}
605+
String prompt = requestBody.asMap().get("prompt").asString();
606+
if (prompt == null) {
607+
return null;
608+
}
609+
count = Document.fromNumber(Math.ceil(prompt.length() / CHARS_PER_TOKEN));
537610
}
538611
if (count != null && count.isNumber()) {
539612
return count.asNumber().longValue();
@@ -604,6 +677,42 @@ private static Long getUsageOutputTokens(
604677
return null;
605678
}
606679
count = usage.asMap().get("output_tokens");
680+
} else if (modelId.startsWith("meta.llama")) {
681+
count = body.asMap().get("generation_token_count");
682+
} else if (modelId.startsWith("cohere.command-r")) {
683+
Document text = body.asMap().get("text");
684+
if (text == null || !text.isString()) {
685+
return null;
686+
}
687+
count = Document.fromNumber(Math.ceil(text.asString().length() / CHARS_PER_TOKEN));
688+
} else if (modelId.startsWith("cohere.command")) {
689+
Document generations = body.asMap().get("generations");
690+
if (generations == null || !generations.isList()) {
691+
return null;
692+
}
693+
long outputLength = 0;
694+
for (Document generation : generations.asList()) {
695+
Document text = generation.asMap().get("text");
696+
if (text == null || !text.isString()) {
697+
continue;
698+
}
699+
outputLength += text.asString().length();
700+
}
701+
count = Document.fromNumber(Math.ceil(outputLength / CHARS_PER_TOKEN));
702+
} else if (modelId.startsWith("mistral.mistral")) {
703+
Document outputs = body.asMap().get("outputs");
704+
if (outputs == null || !outputs.isList()) {
705+
return null;
706+
}
707+
long outputLength = 0;
708+
for (Document output : outputs.asList()) {
709+
Document text = output.asMap().get("text");
710+
if (text == null || !text.isString()) {
711+
continue;
712+
}
713+
outputLength += text.asString().length();
714+
}
715+
count = Document.fromNumber(Math.ceil(outputLength / CHARS_PER_TOKEN));
607716
}
608717
if (count != null && count.isNumber()) {
609718
return count.asNumber().longValue();

0 commit comments

Comments
 (0)