3030import org .springframework .ai .chat .model .ChatResponse ;
3131import org .springframework .ai .chat .prompt .Prompt ;
3232import org .springframework .ai .openai .OpenAiChatModel ;
33+ import org .springframework .ai .openai .OpenAiChatOptions ;
3334import org .springframework .ai .openai .api .OpenAiApi ;
3435import org .springframework .ai .openai .metadata .support .OpenAiApiResponseHeaders ;
3536import org .springframework .beans .factory .annotation .Autowired ;
@@ -73,7 +74,7 @@ void resetMockServer() {
7374 @ Test
7475 void aiResponseContainsAiMetadata () {
7576
76- prepareMock ();
77+ prepareMock (false );
7778
7879 Prompt prompt = new Prompt ("Reach for the sky." );
7980
@@ -118,13 +119,32 @@ void aiResponseContainsAiMetadata() {
118119
119120 response .getResults ().forEach (generation -> {
120121 ChatGenerationMetadata chatGenerationMetadata = generation .getMetadata ();
122+ var logprobs = chatGenerationMetadata .get ("logprobs" );
123+ assertThat (logprobs ).isNull ();
121124 assertThat (chatGenerationMetadata ).isNotNull ();
122125 assertThat (chatGenerationMetadata .getFinishReason ()).isEqualTo ("STOP" );
123126 assertThat (chatGenerationMetadata .getContentFilters ()).isEmpty ();
124127 });
125128 }
126129
127- private void prepareMock () {
130+ @ Test
131+ void aiResponseContainsAiLogprobsMetadata () {
132+
133+ prepareMock (true );
134+
135+ Prompt prompt = new Prompt ("Reach for the sky." , new OpenAiChatOptions .Builder ().logprobs (true ).build ());
136+
137+ ChatResponse response = this .openAiChatClient .call (prompt );
138+
139+ assertThat (response ).isNotNull ();
140+ assertThat (response .getResult ()).isNotNull ();
141+ assertThat (response .getResult ().getMetadata ()).isNotNull ();
142+
143+ var logprobs = response .getResult ().getMetadata ().get ("logprobs" );
144+ assertThat (logprobs ).isNotNull ().isInstanceOf (OpenAiApi .LogProbs .class );
145+ }
146+
147+ private void prepareMock (boolean includeLogprobs ) {
128148
129149 HttpHeaders httpHeaders = new HttpHeaders ();
130150 httpHeaders .set (OpenAiApiResponseHeaders .REQUESTS_LIMIT_HEADER .getName (), "4000" );
@@ -137,34 +157,58 @@ private void prepareMock() {
137157 this .server .expect (requestTo (StringContains .containsString ("/v1/chat/completions" )))
138158 .andExpect (method (HttpMethod .POST ))
139159 .andExpect (header (HttpHeaders .AUTHORIZATION , "Bearer " + TEST_API_KEY ))
140- .andRespond (withSuccess (getJson (), MediaType .APPLICATION_JSON ).headers (httpHeaders ));
160+ .andRespond (withSuccess (getJson (includeLogprobs ), MediaType .APPLICATION_JSON ).headers (httpHeaders ));
141161
142162 }
143163
144- private String getJson () {
164+ private String getBaseJson () {
145165 return """
146- {
147- "id": "chatcmpl-123",
148- "object": "chat.completion",
149- "created": 1677652288,
150- "model": "gpt-3.5-turbo-0613",
151- "choices": [{
152- "index": 0,
153- "message": {
154- "role": "assistant",
155- "content": "I surrender!"
156- },
157- "finish_reason": "stop"
158- }],
159- "usage": {
160- "prompt_tokens": 9,
161- "completion_tokens": 12,
162- "total_tokens": 21
163- }
164- }
166+ {
167+ "id": "chatcmpl-123",
168+ "object": "chat.completion",
169+ "created": 1677652288,
170+ "model": "gpt-3.5-turbo-0613",
171+ "choices": [{
172+ "index": 0,
173+ "message": {
174+ "role": "assistant",
175+ "content": "I surrender!"
176+ },
177+ %s
178+ "finish_reason": "stop"
179+ }],
180+ "usage": {
181+ "prompt_tokens": 9,
182+ "completion_tokens": 12,
183+ "total_tokens": 21
184+ }
185+ }
165186 """ ;
166187 }
167188
189+ private String getJson (boolean includeLogprobs ) {
190+ if (includeLogprobs ) {
191+ String logprobs = """
192+ "logprobs" : {
193+ "content" : [ {
194+ "token" : "I",
195+ "logprob" : -0.029507114,
196+ "bytes" : [ 73 ],
197+ "top_logprobs" : [ ]
198+ }, {
199+ "token" : " surrender!",
200+ "logprob" : -0.061970375,
201+ "bytes" : [ 32, 115, 117, 114, 114, 101, 110, 100, 101, 114, 33 ],
202+ "top_logprobs" : [ ]
203+ } ]
204+ },
205+ """ ;
206+ return String .format (getBaseJson (), logprobs );
207+ }
208+
209+ return String .format (getBaseJson (), "" );
210+ }
211+
168212 @ SpringBootConfiguration
169213 static class Config {
170214
0 commit comments