Skip to content

Commit 4c25399

Browse files
committed
feat: GH-3786 add custom JSON deserializer for OpenAiApi.Embedding
Introduce `OpenAiEmbeddingDeserializer` to handle base64-encoded embeddings returned by the model when encodingFormat=base64 is specified. Signed-off-by: Sun Yuhan <[email protected]>
1 parent aa590e8 commit 4c25399

File tree

3 files changed

+203
-1
lines changed

3 files changed

+203
-1
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.fasterxml.jackson.annotation.JsonInclude;
3030
import com.fasterxml.jackson.annotation.JsonInclude.Include;
3131
import com.fasterxml.jackson.annotation.JsonProperty;
32+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
3233
import reactor.core.publisher.Flux;
3334
import reactor.core.publisher.Mono;
3435

@@ -1793,7 +1794,7 @@ public record ChunkChoice(// @formatter:off
17931794
@JsonIgnoreProperties(ignoreUnknown = true)
17941795
public record Embedding(// @formatter:off
17951796
@JsonProperty("index") Integer index,
1796-
@JsonProperty("embedding") float[] embedding,
1797+
@JsonProperty("embedding") @JsonDeserialize(using = OpenAiEmbeddingDeserializer.class) float[] embedding,
17971798
@JsonProperty("object") String object) { // @formatter:on
17981799

17991800
/**
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.api;
18+
19+
import com.fasterxml.jackson.core.JacksonException;
20+
import com.fasterxml.jackson.core.JsonParser;
21+
import com.fasterxml.jackson.core.JsonToken;
22+
import com.fasterxml.jackson.databind.DeserializationContext;
23+
import com.fasterxml.jackson.databind.JsonDeserializer;
24+
25+
import java.io.IOException;
26+
import java.nio.ByteBuffer;
27+
import java.nio.ByteOrder;
28+
import java.util.Base64;
29+
30+
/**
31+
* Used to deserialize the `embedding` field returned by the model.
32+
* <p>
33+
* Supports two input formats:
34+
* <ol>
35+
* <li>{@code float[]} - returned directly as-is.</li>
36+
* <li>A Base64-encoded string representing a float array stored in little-endian format.
37+
* The string is first decoded into a byte array, then converted into a
38+
* {@code float[]}.</li>
39+
* </ol>
40+
*
41+
* @author Sun Yuhan
42+
*/
43+
public class OpenAiEmbeddingDeserializer extends JsonDeserializer<float[]> {
44+
45+
@Override
46+
public float[] deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
47+
throws IOException, JacksonException {
48+
JsonToken token = jsonParser.currentToken();
49+
if (token == JsonToken.START_ARRAY) {
50+
return jsonParser.readValueAs(float[].class);
51+
}
52+
else if (token == JsonToken.VALUE_STRING) {
53+
String base64 = jsonParser.getValueAsString();
54+
byte[] decodedBytes = Base64.getDecoder().decode(base64);
55+
56+
ByteBuffer byteBuffer = ByteBuffer.wrap(decodedBytes);
57+
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
58+
59+
int floatCount = decodedBytes.length / Float.BYTES;
60+
float[] embeddingArray = new float[floatCount];
61+
62+
for (int i = 0; i < floatCount; i++) {
63+
embeddingArray[i] = byteBuffer.getFloat();
64+
}
65+
return embeddingArray;
66+
}
67+
else {
68+
throw new IOException("Illegal embedding: " + token);
69+
}
70+
}
71+
72+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package org.springframework.ai.openai.api;
2+
3+
import com.fasterxml.jackson.core.JsonParser;
4+
import com.fasterxml.jackson.core.JsonProcessingException;
5+
import com.fasterxml.jackson.core.JsonToken;
6+
import com.fasterxml.jackson.databind.DeserializationContext;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import org.junit.jupiter.api.Test;
9+
10+
import java.io.IOException;
11+
import java.nio.ByteBuffer;
12+
import java.nio.ByteOrder;
13+
import java.util.Base64;
14+
15+
import static org.junit.Assert.assertThrows;
16+
import static org.junit.jupiter.api.Assertions.*;
17+
import static org.mockito.Mockito.mock;
18+
import static org.mockito.Mockito.when;
19+
20+
/**
21+
* Unit tests for {@link OpenAiEmbeddingDeserializer}
22+
*
23+
* @author Sun Yuhan
24+
*/
25+
class OpenAiEmbeddingDeserializerTests {
26+
27+
private final OpenAiEmbeddingDeserializer deserializer = new OpenAiEmbeddingDeserializer();
28+
29+
private final ObjectMapper mapper = new ObjectMapper();
30+
31+
@Test
32+
void testDeserializeFloatArray() throws Exception {
33+
JsonParser parser = mock(JsonParser.class);
34+
DeserializationContext context = mock(DeserializationContext.class);
35+
36+
when(parser.currentToken()).thenReturn(JsonToken.START_ARRAY);
37+
float[] expected = new float[] { 1.0f, 2.0f, 3.0f };
38+
when(parser.readValueAs(float[].class)).thenReturn(expected);
39+
40+
float[] result = deserializer.deserialize(parser, context);
41+
assertArrayEquals(expected, result);
42+
}
43+
44+
@Test
45+
void testDeserializeBase64String() throws Exception {
46+
float[] original = new float[] { 4.2f, -1.5f, 0.0f };
47+
ByteBuffer buffer = ByteBuffer.allocate(original.length * Float.BYTES);
48+
buffer.order(ByteOrder.LITTLE_ENDIAN);
49+
for (float v : original) {
50+
buffer.putFloat(v);
51+
}
52+
String base64 = Base64.getEncoder().encodeToString(buffer.array());
53+
54+
JsonParser parser = mock(JsonParser.class);
55+
DeserializationContext context = mock(DeserializationContext.class);
56+
57+
when(parser.currentToken()).thenReturn(JsonToken.VALUE_STRING);
58+
when(parser.getValueAsString()).thenReturn(base64);
59+
60+
float[] result = deserializer.deserialize(parser, context);
61+
62+
assertArrayEquals(original, result, 0.0001f);
63+
}
64+
65+
@Test
66+
void testDeserializeIllegalToken() {
67+
JsonParser parser = mock(JsonParser.class);
68+
DeserializationContext context = mock(DeserializationContext.class);
69+
70+
when(parser.currentToken()).thenReturn(JsonToken.VALUE_NUMBER_INT);
71+
72+
IOException e = assertThrows(IOException.class, () -> deserializer.deserialize(parser, context));
73+
assertTrue(e.getMessage().contains("Illegal embedding"));
74+
}
75+
76+
@Test
77+
void testDeserializeEmbeddingWithFloatArray() throws Exception {
78+
String json = """
79+
{
80+
"index": 1,
81+
"embedding": [1.0, 2.0, 3.0],
82+
"object": "embedding"
83+
}
84+
""";
85+
OpenAiApi.Embedding embedding = mapper.readValue(json, OpenAiApi.Embedding.class);
86+
assertEquals(1, embedding.index());
87+
assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, embedding.embedding(), 0.0001f);
88+
assertEquals("embedding", embedding.object());
89+
}
90+
91+
@Test
92+
void testDeserializeEmbeddingWithBase64String() throws Exception {
93+
float[] original = new float[] { 4.2f, -1.5f, 0.0f };
94+
ByteBuffer buffer = ByteBuffer.allocate(original.length * Float.BYTES);
95+
buffer.order(ByteOrder.LITTLE_ENDIAN);
96+
for (float v : original)
97+
buffer.putFloat(v);
98+
String base64 = Base64.getEncoder().encodeToString(buffer.array());
99+
100+
String json = """
101+
{
102+
"index": 2,
103+
"embedding": "%s",
104+
"object": "embedding"
105+
}
106+
""".formatted(base64);
107+
108+
OpenAiApi.Embedding embedding = mapper.readValue(json, OpenAiApi.Embedding.class);
109+
assertEquals(2, embedding.index());
110+
assertArrayEquals(original, embedding.embedding(), 0.0001f);
111+
assertEquals("embedding", embedding.object());
112+
}
113+
114+
@Test
115+
void testDeserializeEmbeddingWithWrongType() {
116+
String json = """
117+
{
118+
"index": 3,
119+
"embedding": 123,
120+
"object": "embedding"
121+
}
122+
""";
123+
JsonProcessingException ex = assertThrows(JsonProcessingException.class, () -> {
124+
mapper.readValue(json, OpenAiApi.Embedding.class);
125+
});
126+
assertTrue(ex.getMessage().contains("Illegal embedding"));
127+
}
128+
129+
}

0 commit comments

Comments
 (0)