2222import com .fasterxml .jackson .databind .ObjectMapper ;
2323import org .junit .jupiter .api .Test ;
2424import org .junit .jupiter .api .condition .EnabledIfEnvironmentVariable ;
25+ import org .mockito .ArgumentCaptor ;
2526import software .amazon .awssdk .auth .credentials .EnvironmentVariableCredentialsProvider ;
2627import software .amazon .awssdk .regions .Region ;
2728
3132import org .springframework .ai .embedding .EmbeddingRequest ;
3233import org .springframework .ai .embedding .EmbeddingResponse ;
3334import org .springframework .beans .factory .annotation .Autowired ;
35+ import org .springframework .beans .factory .annotation .Qualifier ;
3436import org .springframework .boot .SpringBootConfiguration ;
3537import org .springframework .boot .test .context .SpringBootTest ;
38+ import org .springframework .boot .test .mock .mockito .SpyBean ;
3639import org .springframework .context .annotation .Bean ;
3740
3841import static org .assertj .core .api .Assertions .assertThat ;
42+ import static org .mockito .Mockito .verify ;
3943
4044@ SpringBootTest
4145@ EnabledIfEnvironmentVariable (named = "AWS_ACCESS_KEY_ID" , matches = ".*" )
@@ -45,6 +49,13 @@ class BedrockCohereEmbeddingModelIT {
4549 @ Autowired
4650 private BedrockCohereEmbeddingModel embeddingModel ;
4751
52+ @ SpyBean
53+ private CohereEmbeddingBedrockApi embeddingApi ;
54+
55+ @ Autowired
56+ @ Qualifier ("embeddingModelStartTruncate" )
57+ private BedrockCohereEmbeddingModel embeddingModelStartTruncate ;
58+
4859 @ Test
4960 void singleEmbedding () {
5061 assertThat (this .embeddingModel ).isNotNull ();
@@ -54,6 +65,77 @@ void singleEmbedding() {
5465 assertThat (this .embeddingModel .dimensions ()).isEqualTo (1024 );
5566 }
5667
68+ @ Test
69+ void truncatesLongText () {
70+ String longText = "Hello World" .repeat (300 );
71+ assertThat (longText .length ()).isGreaterThan (2048 );
72+
73+ EmbeddingResponse embeddingResponse = this .embeddingModel .embedForResponse (List .of (longText ));
74+
75+ assertThat (embeddingResponse .getResults ()).hasSize (1 );
76+ assertThat (embeddingResponse .getResults ().get (0 ).getOutput ()).isNotEmpty ();
77+ assertThat (this .embeddingModel .dimensions ()).isEqualTo (1024 );
78+ }
79+
80+ @ Test
81+ void truncatesMultipleLongTexts () {
82+ String longText1 = "Hello World" .repeat (300 );
83+ String longText2 = "Another Text" .repeat (300 );
84+
85+ EmbeddingResponse embeddingResponse = this .embeddingModel .embedForResponse (List .of (longText1 , longText2 ));
86+
87+ assertThat (embeddingResponse .getResults ()).hasSize (2 );
88+ assertThat (embeddingResponse .getResults ().get (0 ).getOutput ()).isNotEmpty ();
89+ assertThat (embeddingResponse .getResults ().get (1 ).getOutput ()).isNotEmpty ();
90+ assertThat (this .embeddingModel .dimensions ()).isEqualTo (1024 );
91+ }
92+
93+ @ Test
94+ void verifyExactTruncationLength () {
95+ String longText = "x" .repeat (3000 );
96+
97+ ArgumentCaptor <CohereEmbeddingBedrockApi .CohereEmbeddingRequest > requestCaptor = ArgumentCaptor
98+ .forClass (CohereEmbeddingBedrockApi .CohereEmbeddingRequest .class );
99+
100+ EmbeddingResponse embeddingResponse = embeddingModel .embedForResponse (List .of (longText ));
101+
102+ verify (embeddingApi ).embedding (requestCaptor .capture ());
103+ CohereEmbeddingBedrockApi .CohereEmbeddingRequest capturedRequest = requestCaptor .getValue ();
104+
105+ assertThat (capturedRequest .texts ()).hasSize (1 );
106+ assertThat (capturedRequest .texts ().get (0 ).length ()).isLessThanOrEqualTo (2048 );
107+
108+ assertThat (embeddingResponse .getResults ()).hasSize (1 );
109+ assertThat (embeddingResponse .getResults ().get (0 ).getOutput ()).isNotEmpty ();
110+ }
111+
112+ @ Test
113+ void truncatesLongTextFromStart () {
114+ String startMarker = "START_MARKER_" ;
115+ String endMarker = "_END_MARKER" ;
116+ String middlePadding = "x" .repeat (2500 ); // Long enough to force truncation
117+ String longText = startMarker + middlePadding + endMarker ;
118+
119+ assertThat (longText .length ()).isGreaterThan (2048 );
120+
121+ ArgumentCaptor <CohereEmbeddingBedrockApi .CohereEmbeddingRequest > requestCaptor = ArgumentCaptor
122+ .forClass (CohereEmbeddingBedrockApi .CohereEmbeddingRequest .class );
123+
124+ EmbeddingResponse embeddingResponse = this .embeddingModelStartTruncate .embedForResponse (List .of (longText ));
125+
126+ // Verify truncation behavior
127+ verify (embeddingApi ).embedding (requestCaptor .capture ());
128+ String truncatedText = requestCaptor .getValue ().texts ().get (0 );
129+ assertThat (truncatedText .length ()).isLessThanOrEqualTo (2048 );
130+ assertThat (truncatedText ).doesNotContain (startMarker );
131+ assertThat (truncatedText ).endsWith (endMarker );
132+
133+ // Verify embedding response
134+ assertThat (embeddingResponse .getResults ()).hasSize (1 );
135+ assertThat (embeddingResponse .getResults ().get (0 ).getOutput ()).isNotEmpty ();
136+ assertThat (this .embeddingModelStartTruncate .dimensions ()).isEqualTo (1024 );
137+ }
138+
57139 @ Test
58140 void batchEmbedding () {
59141 assertThat (this .embeddingModel ).isNotNull ();
@@ -93,9 +175,27 @@ public CohereEmbeddingBedrockApi cohereEmbeddingApi() {
93175 Duration .ofMinutes (2 ));
94176 }
95177
96- @ Bean
178+ @ Bean ( "embeddingModel" )
97179 public BedrockCohereEmbeddingModel cohereAiEmbedding (CohereEmbeddingBedrockApi cohereEmbeddingApi ) {
98- return new BedrockCohereEmbeddingModel (cohereEmbeddingApi );
180+ // custom model that uses the END truncation strategy, instead of the default
181+ // NONE.
182+ return new BedrockCohereEmbeddingModel (cohereEmbeddingApi ,
183+ BedrockCohereEmbeddingOptions .builder ()
184+ .withInputType (CohereEmbeddingBedrockApi .CohereEmbeddingRequest .InputType .SEARCH_DOCUMENT )
185+ .withTruncate (CohereEmbeddingBedrockApi .CohereEmbeddingRequest .Truncate .END )
186+ .build ());
187+ }
188+
189+ @ Bean ("embeddingModelStartTruncate" )
190+ public BedrockCohereEmbeddingModel cohereAiEmbeddingStartTruncate (
191+ CohereEmbeddingBedrockApi cohereEmbeddingApi ) {
192+ // custom model that uses the START truncation strategy, instead of the
193+ // default NONE.
194+ return new BedrockCohereEmbeddingModel (cohereEmbeddingApi ,
195+ BedrockCohereEmbeddingOptions .builder ()
196+ .withInputType (CohereEmbeddingBedrockApi .CohereEmbeddingRequest .InputType .SEARCH_DOCUMENT )
197+ .withTruncate (CohereEmbeddingBedrockApi .CohereEmbeddingRequest .Truncate .START )
198+ .build ());
99199 }
100200
101201 }
0 commit comments