diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 5cfe18ac9b4..e4f04a89041 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import io.micrometer.observation.Observation; @@ -348,6 +349,11 @@ public static class DefaultCallResponseSpec implements CallResponseSpec { private final ChatClientObservationConvention observationConvention; + /** + * Used to ensure that the {@link CallResponseSpec} is used only once. + */ + private final AtomicBoolean used = new AtomicBoolean(false); + public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain, ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) { Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); @@ -450,6 +456,10 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest, @Nullable String outputFormat) { + if (!this.used.compareAndSet(false, true)) { + throw new IllegalStateException("The CallResponseSpec instance can only be used once."); + } + if (outputFormat != null) { chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), outputFormat); } @@ -494,6 +504,11 @@ public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final ChatClientObservationConvention observationConvention; + /** + * Used to ensure that the {@link StreamResponseSpec} is used only once. + */ + private final AtomicBoolean used = new AtomicBoolean(false); + public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain, ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) { Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); @@ -508,6 +523,11 @@ public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdviso } private Flux doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) { + + if (!this.used.compareAndSet(false, true)) { + throw new IllegalStateException("The StreamResponseSpec instance can only be used once."); + } + return Flux.deferContextual(contextView -> { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 1c72596f490..8b9f231507f 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -58,6 +58,7 @@ import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -686,6 +687,19 @@ void buildCallResponseSpecWithNullObservationConvention() { .hasMessage("observationConvention cannot be null"); } + @Test + void whenUseCallResponseSpecTwiceThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("question"); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); + + assertThatNoException().isThrownBy(() -> spec.content()); + assertThatThrownBy(() -> spec.content()).isInstanceOf(IllegalStateException.class) + .hasMessage("The CallResponseSpec instance can only be used once."); + } + @Test void whenSimplePromptThenChatClientResponse() { ChatModel chatModel = mock(ChatModel.class); @@ -1199,6 +1213,19 @@ void buildStreamResponseSpecWithNullObservationConvention() { .hasMessage("observationConvention cannot be null"); } + @Test + void whenUseStreamResponseSpecTwiceThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("question"); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); + + assertThatNoException().isThrownBy(() -> spec.content()); + assertThatThrownBy(() -> spec.content()).isInstanceOf(IllegalStateException.class) + .hasMessage("The StreamResponseSpec instance can only be used once."); + } + @Test void whenSimplePromptThenFluxChatClientResponse() { ChatModel chatModel = mock(ChatModel.class);