Skip to content

Commit 9e43b7e

Browse files
sunyuhan1998ilayaperumalg
authored andcommitted
feat: spring-projectsGH-3786 add custom JSON deserializer for OpenAiApi.Embedding
Introduce `OpenAiEmbeddingDeserializer` to handle base64-encoded embeddings returned by the model when encodingFormat=base64 is specified. Signed-off-by: Sun Yuhan <[email protected]>
1 parent 4ed5417 commit 9e43b7e

File tree

3 files changed

+221
-1
lines changed

3 files changed

+221
-1
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.fasterxml.jackson.annotation.JsonInclude;
3030
import com.fasterxml.jackson.annotation.JsonInclude.Include;
3131
import com.fasterxml.jackson.annotation.JsonProperty;
32+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
3233
import reactor.core.publisher.Flux;
3334
import reactor.core.publisher.Mono;
3435

@@ -1876,7 +1877,7 @@ public record ChunkChoice(// @formatter:off
18761877
@JsonIgnoreProperties(ignoreUnknown = true)
18771878
public record Embedding(// @formatter:off
18781879
@JsonProperty("index") Integer index,
1879-
@JsonProperty("embedding") float[] embedding,
1880+
@JsonProperty("embedding") @JsonDeserialize(using = OpenAiEmbeddingDeserializer.class) float[] embedding,
18801881
@JsonProperty("object") String object) { // @formatter:on
18811882

18821883
/**
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.api;
18+
19+
import java.io.IOException;
20+
import java.nio.ByteBuffer;
21+
import java.nio.ByteOrder;
22+
import java.util.Base64;
23+
24+
import com.fasterxml.jackson.core.JacksonException;
25+
import com.fasterxml.jackson.core.JsonParser;
26+
import com.fasterxml.jackson.core.JsonToken;
27+
import com.fasterxml.jackson.databind.DeserializationContext;
28+
import com.fasterxml.jackson.databind.JsonDeserializer;
29+
30+
/**
31+
* Used to deserialize the `embedding` field returned by the model.
32+
* <p>
33+
* Supports two input formats:
34+
* <ol>
35+
* <li>{@code float[]} - returned directly as-is.</li>
36+
* <li>A Base64-encoded string representing a float array stored in little-endian format.
37+
* The string is first decoded into a byte array, then converted into a
38+
* {@code float[]}.</li>
39+
* </ol>
40+
*
41+
* @author Sun Yuhan
42+
*/
43+
public class OpenAiEmbeddingDeserializer extends JsonDeserializer<float[]> {
44+
45+
@Override
46+
public float[] deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
47+
throws IOException, JacksonException {
48+
JsonToken token = jsonParser.currentToken();
49+
if (token == JsonToken.START_ARRAY) {
50+
return jsonParser.readValueAs(float[].class);
51+
}
52+
else if (token == JsonToken.VALUE_STRING) {
53+
String base64 = jsonParser.getValueAsString();
54+
byte[] decodedBytes = Base64.getDecoder().decode(base64);
55+
56+
ByteBuffer byteBuffer = ByteBuffer.wrap(decodedBytes);
57+
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
58+
59+
int floatCount = decodedBytes.length / Float.BYTES;
60+
float[] embeddingArray = new float[floatCount];
61+
62+
for (int i = 0; i < floatCount; i++) {
63+
embeddingArray[i] = byteBuffer.getFloat();
64+
}
65+
return embeddingArray;
66+
}
67+
else {
68+
throw new IOException("Illegal embedding: " + token);
69+
}
70+
}
71+
72+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.api;
18+
19+
import java.io.IOException;
20+
import java.nio.ByteBuffer;
21+
import java.nio.ByteOrder;
22+
import java.util.Base64;
23+
24+
import com.fasterxml.jackson.core.JsonParser;
25+
import com.fasterxml.jackson.core.JsonProcessingException;
26+
import com.fasterxml.jackson.core.JsonToken;
27+
import com.fasterxml.jackson.databind.DeserializationContext;
28+
import com.fasterxml.jackson.databind.ObjectMapper;
29+
import org.junit.jupiter.api.Test;
30+
31+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
32+
import static org.junit.jupiter.api.Assertions.assertEquals;
33+
import static org.junit.jupiter.api.Assertions.assertThrows;
34+
import static org.junit.jupiter.api.Assertions.assertTrue;
35+
import static org.mockito.Mockito.mock;
36+
import static org.mockito.Mockito.when;
37+
38+
/**
39+
* Unit tests for {@link OpenAiEmbeddingDeserializer}
40+
*
41+
* @author Sun Yuhan
42+
*/
43+
class OpenAiEmbeddingDeserializerTests {
44+
45+
private final OpenAiEmbeddingDeserializer deserializer = new OpenAiEmbeddingDeserializer();
46+
47+
private final ObjectMapper mapper = new ObjectMapper();
48+
49+
@Test
50+
void testDeserializeFloatArray() throws Exception {
51+
JsonParser parser = mock(JsonParser.class);
52+
DeserializationContext context = mock(DeserializationContext.class);
53+
54+
when(parser.currentToken()).thenReturn(JsonToken.START_ARRAY);
55+
float[] expected = new float[] { 1.0f, 2.0f, 3.0f };
56+
when(parser.readValueAs(float[].class)).thenReturn(expected);
57+
58+
float[] result = this.deserializer.deserialize(parser, context);
59+
assertArrayEquals(expected, result);
60+
}
61+
62+
@Test
63+
void testDeserializeBase64String() throws Exception {
64+
float[] original = new float[] { 4.2f, -1.5f, 0.0f };
65+
ByteBuffer buffer = ByteBuffer.allocate(original.length * Float.BYTES);
66+
buffer.order(ByteOrder.LITTLE_ENDIAN);
67+
for (float v : original) {
68+
buffer.putFloat(v);
69+
}
70+
String base64 = Base64.getEncoder().encodeToString(buffer.array());
71+
72+
JsonParser parser = mock(JsonParser.class);
73+
DeserializationContext context = mock(DeserializationContext.class);
74+
75+
when(parser.currentToken()).thenReturn(JsonToken.VALUE_STRING);
76+
when(parser.getValueAsString()).thenReturn(base64);
77+
78+
float[] result = this.deserializer.deserialize(parser, context);
79+
80+
assertArrayEquals(original, result, 0.0001f);
81+
}
82+
83+
@Test
84+
void testDeserializeIllegalToken() {
85+
JsonParser parser = mock(JsonParser.class);
86+
DeserializationContext context = mock(DeserializationContext.class);
87+
88+
when(parser.currentToken()).thenReturn(JsonToken.VALUE_NUMBER_INT);
89+
90+
IOException e = assertThrows(IOException.class, () -> this.deserializer.deserialize(parser, context));
91+
assertTrue(e.getMessage().contains("Illegal embedding"));
92+
}
93+
94+
@Test
95+
void testDeserializeEmbeddingWithFloatArray() throws Exception {
96+
String json = """
97+
{
98+
"index": 1,
99+
"embedding": [1.0, 2.0, 3.0],
100+
"object": "embedding"
101+
}
102+
""";
103+
OpenAiApi.Embedding embedding = this.mapper.readValue(json, OpenAiApi.Embedding.class);
104+
assertEquals(1, embedding.index());
105+
assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, embedding.embedding(), 0.0001f);
106+
assertEquals("embedding", embedding.object());
107+
}
108+
109+
@Test
110+
void testDeserializeEmbeddingWithBase64String() throws Exception {
111+
float[] original = new float[] { 4.2f, -1.5f, 0.0f };
112+
ByteBuffer buffer = ByteBuffer.allocate(original.length * Float.BYTES);
113+
buffer.order(ByteOrder.LITTLE_ENDIAN);
114+
for (float v : original) {
115+
buffer.putFloat(v);
116+
}
117+
String base64 = Base64.getEncoder().encodeToString(buffer.array());
118+
119+
String json = """
120+
{
121+
"index": 2,
122+
"embedding": "%s",
123+
"object": "embedding"
124+
}
125+
""".formatted(base64);
126+
127+
OpenAiApi.Embedding embedding = this.mapper.readValue(json, OpenAiApi.Embedding.class);
128+
assertEquals(2, embedding.index());
129+
assertArrayEquals(original, embedding.embedding(), 0.0001f);
130+
assertEquals("embedding", embedding.object());
131+
}
132+
133+
@Test
134+
void testDeserializeEmbeddingWithWrongType() {
135+
String json = """
136+
{
137+
"index": 3,
138+
"embedding": 123,
139+
"object": "embedding"
140+
}
141+
""";
142+
JsonProcessingException ex = assertThrows(JsonProcessingException.class,
143+
() -> this.mapper.readValue(json, OpenAiApi.Embedding.class));
144+
assertTrue(ex.getMessage().contains("Illegal embedding"));
145+
}
146+
147+
}

0 commit comments

Comments
 (0)