Skip to content

Commit 9b24765

Browse files
committed
fix: Make sure that XXXResponseSpec instances can only be used once
- Add an `AtomicBoolean` to DefaultCallResponseSpec and DefaultStreamResponseSpec to track whether the instance has been used - Check if the instance has already been used before calling `advisorChain.nextCall`, and throw an `IllegalStateException` if it has - Add unit tests to confirm that case Signed-off-by: YunKui Lu <[email protected]>
1 parent daf1274 commit 9b24765

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.List;
2727
import java.util.Map;
2828
import java.util.Optional;
29+
import java.util.concurrent.atomic.AtomicBoolean;
2930
import java.util.function.Consumer;
3031

3132
import io.micrometer.observation.Observation;
@@ -348,6 +349,11 @@ public static class DefaultCallResponseSpec implements CallResponseSpec {
348349

349350
private final ChatClientObservationConvention observationConvention;
350351

352+
/**
353+
* Used to ensure that the {@link CallResponseSpec} is used only once.
354+
*/
355+
private final AtomicBoolean used = new AtomicBoolean(false);
356+
351357
public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
352358
ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
353359
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
@@ -450,6 +456,10 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c
450456
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest,
451457
@Nullable String outputFormat) {
452458

459+
if (!this.used.compareAndSet(false, true)) {
460+
throw new IllegalStateException("The CallResponseSpec instance can only be used once.");
461+
}
462+
453463
if (outputFormat != null) {
454464
chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), outputFormat);
455465
}
@@ -494,6 +504,11 @@ public static class DefaultStreamResponseSpec implements StreamResponseSpec {
494504

495505
private final ChatClientObservationConvention observationConvention;
496506

507+
/**
508+
* Used to ensure that the {@link StreamResponseSpec} is used only once.
509+
*/
510+
private final AtomicBoolean used = new AtomicBoolean(false);
511+
497512
public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
498513
ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
499514
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
@@ -508,6 +523,11 @@ public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdviso
508523
}
509524

510525
private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
526+
527+
if (!this.used.compareAndSet(false, true)) {
528+
throw new IllegalStateException("The StreamResponseSpec instance can only be used once.");
529+
}
530+
511531
return Flux.deferContextual(contextView -> {
512532

513533
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.springframework.util.MimeTypeUtils;
5959

6060
import static org.assertj.core.api.Assertions.assertThat;
61+
import static org.assertj.core.api.Assertions.assertThatNoException;
6162
import static org.assertj.core.api.Assertions.assertThatThrownBy;
6263
import static org.mockito.BDDMockito.given;
6364
import static org.mockito.Mockito.mock;
@@ -686,6 +687,19 @@ void buildCallResponseSpecWithNullObservationConvention() {
686687
.hasMessage("observationConvention cannot be null");
687688
}
688689

690+
@Test
691+
void whenUseCallResponseSpecTwiceThenThrow() {
692+
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
693+
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
694+
.prompt("question");
695+
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
696+
.call();
697+
698+
assertThatNoException().isThrownBy(() -> spec.content());
699+
assertThatThrownBy(() -> spec.content()).isInstanceOf(IllegalStateException.class)
700+
.hasMessage("The CallResponseSpec instance can only be used once.");
701+
}
702+
689703
@Test
690704
void whenSimplePromptThenChatClientResponse() {
691705
ChatModel chatModel = mock(ChatModel.class);
@@ -1199,6 +1213,19 @@ void buildStreamResponseSpecWithNullObservationConvention() {
11991213
.hasMessage("observationConvention cannot be null");
12001214
}
12011215

1216+
@Test
1217+
void whenUseStreamResponseSpecTwiceThenThrow() {
1218+
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
1219+
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
1220+
.prompt("question");
1221+
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
1222+
.stream();
1223+
1224+
assertThatNoException().isThrownBy(() -> spec.content());
1225+
assertThatThrownBy(() -> spec.content()).isInstanceOf(IllegalStateException.class)
1226+
.hasMessage("The StreamResponseSpec instance can only be used once.");
1227+
}
1228+
12021229
@Test
12031230
void whenSimplePromptThenFluxChatClientResponse() {
12041231
ChatModel chatModel = mock(ChatModel.class);

0 commit comments

Comments
 (0)