Skip to content

Commit 9080690

Browse files
committed
feat(vcr): add Spring AI model wrappers for VCR
Implements VCR wrappers for Spring AI models: - VCRSpringAIEmbeddingModel: wraps EmbeddingModel for recording/replaying - VCRSpringAIChatModel: wraps ChatModel for recording/replaying chat Features: - Full Spring AI interface implementation - Support for embedForResponse() and call() methods - Redis-backed cassette storage integration - In-memory cassette cache for unit testing - Statistics tracking for cache hits, misses, and recordings Unit tests verify recording and playback behavior.
1 parent 4917e6d commit 9080690

File tree

4 files changed

+1336
-0
lines changed

4 files changed

+1336
-0
lines changed
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
package com.redis.vl.test.vcr;
2+
3+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
4+
import java.util.HashMap;
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.concurrent.atomic.AtomicInteger;
8+
import org.springframework.ai.chat.messages.AssistantMessage;
9+
import org.springframework.ai.chat.messages.Message;
10+
import org.springframework.ai.chat.model.ChatModel;
11+
import org.springframework.ai.chat.model.ChatResponse;
12+
import org.springframework.ai.chat.model.Generation;
13+
import org.springframework.ai.chat.prompt.Prompt;
14+
15+
/**
16+
* VCR wrapper for Spring AI ChatModel that records and replays LLM responses.
17+
*
18+
* <p>This class implements the ChatModel interface, allowing it to be used as a drop-in replacement
19+
* for any Spring AI chat model. It provides VCR (Video Cassette Recorder) functionality to record
20+
* LLM responses during test execution and replay them in subsequent runs.
21+
*
22+
* <p>Usage:
23+
*
24+
* <pre>{@code
25+
* ChatModel openAiModel = new OpenAiChatModel(openAiApi);
26+
*
27+
* VCRSpringAIChatModel vcrModel = new VCRSpringAIChatModel(openAiModel);
28+
* vcrModel.setMode(VCRMode.PLAYBACK_OR_RECORD);
29+
* vcrModel.setTestId("MyTest.testMethod");
30+
*
31+
* // Use exactly like the original model
32+
* String response = vcrModel.call("Hello");
33+
* }</pre>
34+
*/
35+
@SuppressFBWarnings(
36+
value = "EI_EXPOSE_REP2",
37+
justification = "Delegate is intentionally stored and exposed for VCR functionality")
38+
public final class VCRSpringAIChatModel implements ChatModel {
39+
40+
private final ChatModel delegate;
41+
private VCRCassetteStore cassetteStore;
42+
private VCRMode mode = VCRMode.PLAYBACK_OR_RECORD;
43+
private String testId = "unknown";
44+
private final AtomicInteger callCounter = new AtomicInteger(0);
45+
46+
// In-memory cassette storage for unit tests
47+
private final Map<String, String> cassettes = new HashMap<>();
48+
49+
// Statistics
50+
private int cacheHits = 0;
51+
private int cacheMisses = 0;
52+
private int recordedCount = 0;
53+
54+
/**
55+
* Creates a new VCRSpringAIChatModel wrapping the given delegate.
56+
*
57+
* @param delegate The actual ChatModel to wrap
58+
*/
59+
public VCRSpringAIChatModel(ChatModel delegate) {
60+
this.delegate = delegate;
61+
}
62+
63+
/**
64+
* Creates a new VCRSpringAIChatModel wrapping the given delegate with Redis storage.
65+
*
66+
* @param delegate The actual ChatModel to wrap
67+
* @param cassetteStore The cassette store for persistence
68+
*/
69+
@SuppressFBWarnings(
70+
value = "EI_EXPOSE_REP2",
71+
justification = "VCRCassetteStore is intentionally shared")
72+
public VCRSpringAIChatModel(ChatModel delegate, VCRCassetteStore cassetteStore) {
73+
this.delegate = delegate;
74+
this.cassetteStore = cassetteStore;
75+
}
76+
77+
/**
78+
* Sets the VCR mode.
79+
*
80+
* @param mode The VCR mode to use
81+
*/
82+
public void setMode(VCRMode mode) {
83+
this.mode = mode;
84+
}
85+
86+
/**
87+
* Gets the current VCR mode.
88+
*
89+
* @return The current VCR mode
90+
*/
91+
public VCRMode getMode() {
92+
return mode;
93+
}
94+
95+
/**
96+
* Sets the test identifier for cassette key generation.
97+
*
98+
* @param testId The test identifier (typically ClassName.methodName)
99+
*/
100+
public void setTestId(String testId) {
101+
this.testId = testId;
102+
}
103+
104+
/**
105+
* Gets the current test identifier.
106+
*
107+
* @return The current test identifier
108+
*/
109+
public String getTestId() {
110+
return testId;
111+
}
112+
113+
/** Resets the call counter. Useful when starting a new test method. */
114+
public void resetCallCounter() {
115+
callCounter.set(0);
116+
}
117+
118+
/**
119+
* Gets the underlying delegate model.
120+
*
121+
* @return The wrapped ChatModel
122+
*/
123+
@SuppressFBWarnings(
124+
value = "EI_EXPOSE_REP",
125+
justification = "Intentional exposure of delegate for advanced use cases")
126+
public ChatModel getDelegate() {
127+
return delegate;
128+
}
129+
130+
/**
131+
* Preloads a cassette for testing purposes.
132+
*
133+
* @param key The cassette key
134+
* @param response The response text to cache
135+
*/
136+
public void preloadCassette(String key, String response) {
137+
cassettes.put(key, response);
138+
}
139+
140+
/**
141+
* Gets the number of cache hits.
142+
*
143+
* @return Cache hit count
144+
*/
145+
public int getCacheHits() {
146+
return cacheHits;
147+
}
148+
149+
/**
150+
* Gets the number of cache misses.
151+
*
152+
* @return Cache miss count
153+
*/
154+
public int getCacheMisses() {
155+
return cacheMisses;
156+
}
157+
158+
/**
159+
* Gets the number of recorded responses.
160+
*
161+
* @return Recorded count
162+
*/
163+
public int getRecordedCount() {
164+
return recordedCount;
165+
}
166+
167+
/** Resets all statistics. */
168+
public void resetStatistics() {
169+
cacheHits = 0;
170+
cacheMisses = 0;
171+
recordedCount = 0;
172+
}
173+
174+
@Override
175+
public ChatResponse call(Prompt prompt) {
176+
String responseText =
177+
callInternal(
178+
() -> {
179+
ChatResponse response = delegate.call(prompt);
180+
return response.getResult().getOutput().getText();
181+
});
182+
Generation generation = new Generation(new AssistantMessage(responseText));
183+
return new ChatResponse(List.of(generation));
184+
}
185+
186+
@Override
187+
public String call(String message) {
188+
return callInternal(() -> delegate.call(message));
189+
}
190+
191+
@Override
192+
public String call(Message... messages) {
193+
return callInternal(() -> delegate.call(messages));
194+
}
195+
196+
private String callInternal(java.util.function.Supplier<String> delegateCall) {
197+
if (mode == VCRMode.OFF) {
198+
return delegateCall.get();
199+
}
200+
201+
String key = formatKey();
202+
203+
if (mode.isPlaybackMode()) {
204+
String cached = loadCassette(key);
205+
if (cached != null) {
206+
cacheHits++;
207+
return cached;
208+
}
209+
210+
if (mode == VCRMode.PLAYBACK) {
211+
throw new VCRCassetteMissingException(key, testId);
212+
}
213+
214+
// PLAYBACK_OR_RECORD - fall through to record
215+
}
216+
217+
// Record mode or cache miss in PLAYBACK_OR_RECORD
218+
cacheMisses++;
219+
String response = delegateCall.get();
220+
saveCassette(key, response);
221+
recordedCount++;
222+
223+
return response;
224+
}
225+
226+
private String loadCassette(String key) {
227+
// Check in-memory first
228+
String inMemory = cassettes.get(key);
229+
if (inMemory != null) {
230+
return inMemory;
231+
}
232+
233+
// Check Redis if available
234+
if (cassetteStore != null) {
235+
com.google.gson.JsonObject cassette = cassetteStore.retrieve(key);
236+
if (cassette != null && cassette.has("response")) {
237+
return cassette.get("response").getAsString();
238+
}
239+
}
240+
241+
return null;
242+
}
243+
244+
private void saveCassette(String key, String response) {
245+
// Save to in-memory
246+
cassettes.put(key, response);
247+
248+
// Save to Redis if available
249+
if (cassetteStore != null) {
250+
com.google.gson.JsonObject cassette = new com.google.gson.JsonObject();
251+
cassette.addProperty("response", response);
252+
cassette.addProperty("testId", testId);
253+
cassette.addProperty("type", "chat");
254+
cassetteStore.store(key, cassette);
255+
}
256+
}
257+
258+
private String formatKey() {
259+
int index = callCounter.incrementAndGet();
260+
return String.format("vcr:chat:%s:%04d", testId, index);
261+
}
262+
}

0 commit comments

Comments
 (0)