@@ -13,7 +13,6 @@ use aws_sdk_bedrockruntime::types::{
1313 ContentBlock , ContentBlockDelta , ConversationRole , ConverseStreamOutput ,
1414 InferenceConfiguration , Message ,
1515} ;
16- use serde:: { Deserialize , Serialize } ;
1716
1817pub struct LargeLanguageModel {
1918 #[ expect( dead_code) ]
@@ -22,65 +21,69 @@ pub struct LargeLanguageModel {
2221 #[ expect( dead_code) ]
2322 bedrock_client : aws_sdk_bedrock:: Client ,
2423 inference_parameters : InferenceConfiguration ,
25- model_id : ArgModel ,
24+ model_id : String ,
2625}
2726
28- #[ derive( Clone , Serialize , Deserialize , Debug , Copy ) ]
29- pub enum ArgModel {
30- Llama270b ,
31- CohereCommand ,
32- ClaudeV2 ,
33- ClaudeV21 ,
34- ClaudeV3Sonnet ,
35- ClaudeV3Haiku ,
36- ClaudeV35Sonnet ,
37- Jurrasic2Ultra ,
38- TitanTextExpressV1 ,
39- Mixtral8x7bInstruct ,
40- Mistral7bInstruct ,
41- MistralLarge ,
42- MistralLarge2 ,
43- }
44-
45- impl std:: fmt:: Display for ArgModel {
46- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
47- write ! ( f, "{}" , self . model_id_str( ) )
48- }
49- }
50-
51- impl ArgModel {
52- pub fn model_id_str ( & self ) -> & ' static str {
53- match self {
54- ArgModel :: ClaudeV2 => "anthropic.claude-v2" ,
55- ArgModel :: ClaudeV21 => "anthropic.claude-v2:1" ,
56- ArgModel :: ClaudeV3Haiku => "anthropic.claude-3-haiku-20240307-v1:0" ,
57- ArgModel :: ClaudeV3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0" ,
58- ArgModel :: ClaudeV35Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0" ,
59- ArgModel :: Llama270b => "meta.llama2-70b-chat-v1" ,
60- ArgModel :: CohereCommand => "cohere.command-text-v14" ,
61- ArgModel :: Jurrasic2Ultra => "ai21.j2-ultra-v1" ,
62- ArgModel :: TitanTextExpressV1 => "amazon.titan-text-express-v1" ,
63- ArgModel :: Mixtral8x7bInstruct => "mistral.mixtral-8x7b-instruct-v0:1" ,
64- ArgModel :: Mistral7bInstruct => "mistral.mistral-7b-instruct-v0:2" ,
65- ArgModel :: MistralLarge => "mistral.mistral-large-2402-v1:0" ,
66- ArgModel :: MistralLarge2 => "mistral.mistral-large-2407-v1:0" ,
67- }
68- }
69- }
27+ const MODELS : & [ ( & str , & str ) ] = & [
28+ ( "ClaudeV2" , "anthropic.claude-v2" ) ,
29+ ( "ClaudeV21" , "anthropic.claude-v2:1" ) ,
30+ ( "ClaudeV3Haiku" , "anthropic.claude-3-haiku-20240307-v1:0" ) ,
31+ ( "ClaudeV3Sonnet" , "anthropic.claude-3-sonnet-20240229-v1:0" ) ,
32+ (
33+ "ClaudeV35Sonnet" ,
34+ "anthropic.claude-3-5-sonnet-20240620-v1:0" ,
35+ ) ,
36+ ( "Llama270b" , "meta.llama2-70b-chat-v1" ) ,
37+ ( "CohereCommand" , "cohere.command-text-v14" ) ,
38+ ( "Jurrasic2Ultra" , "ai21.j2-ultra-v1" ) ,
39+ ( "TitanTextExpressV1" , "amazon.titan-text-express-v1" ) ,
40+ ( "Mixtral8x7bInstruct" , "mistral.mixtral-8x7b-instruct-v0:1" ) ,
41+ ( "Mistral7bInstruct" , "mistral.mistral-7b-instruct-v0:2" ) ,
42+ ( "MistralLarge" , "mistral.mistral-large-2402-v1:0" ) ,
43+ ( "MistralLarge2" , "mistral.mistral-large-2407-v1:0" ) ,
44+ ] ;
7045
7146impl LargeLanguageModel {
72- pub async fn new ( ) -> Self {
73- let aws_config = Self :: aws_config ( "us-east-1" , "default" ) . await ;
47+ pub async fn new ( model_id : Option < & str > , region : Option < & str > ) -> anyhow:: Result < Self > {
48+ let model_id = Self :: lookup_model_id ( model_id) ?;
49+ let region = region. unwrap_or ( "us-east-1" ) ;
50+
51+ let aws_config = Self :: aws_config ( region, "default" ) . await ;
7452 let bedrock_runtime_client = aws_sdk_bedrockruntime:: Client :: new ( & aws_config) ;
7553 let bedrock_client = aws_sdk_bedrock:: Client :: new ( & aws_config) ;
7654 let inference_parameters = InferenceConfiguration :: builder ( ) . build ( ) ;
77- Self {
55+ Ok ( Self {
7856 aws_config,
7957 bedrock_runtime_client,
8058 bedrock_client,
8159 inference_parameters,
82- model_id : ArgModel :: ClaudeV3Sonnet ,
60+ model_id,
61+ } )
62+ }
63+
64+ fn lookup_model_id ( model_id : Option < & str > ) -> anyhow:: Result < String > {
65+ let Some ( s) = model_id else {
66+ return Self :: lookup_model_id ( Some ( "ClaudeV3Sonnet" ) ) ;
67+ } ;
68+
69+ if s. contains ( "." ) {
70+ return Ok ( s. to_string ( ) ) ;
8371 }
72+
73+ for & ( key, value) in MODELS {
74+ if key == s {
75+ return Ok ( value. to_string ( ) ) ;
76+ }
77+ }
78+
79+ anyhow:: bail!(
80+ "unknown model-id; try one of the following: [{}]" ,
81+ MODELS
82+ . iter( )
83+ . map( |& ( k, _) | k)
84+ . collect:: <Vec <_>>( )
85+ . join( ", " )
86+ ) ;
8487 }
8588
8689 pub async fn query ( & self , prompt : & str , query : & str ) -> anyhow:: Result < String > {
@@ -89,7 +92,7 @@ impl LargeLanguageModel {
8992 let mut output = self
9093 . bedrock_runtime_client
9194 . converse_stream ( )
92- . model_id ( self . model_id . model_id_str ( ) )
95+ . model_id ( & self . model_id )
9396 . messages (
9497 Message :: builder ( )
9598 . role ( ConversationRole :: Assistant )
0 commit comments