Skip to content

Commit ae09200

Browse files
committed
feat: GH-4251 Add a getRequiredVariables method to TemplateRenderer to parse the required variables from the template.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent a5685a1 commit ae09200

File tree

7 files changed

+43
-1
lines changed

7 files changed

+43
-1
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ static Stream<ChatModel> openAiCompatibleApis() {
6868
.openAiApi(OpenAiApi.builder()
6969
.baseUrl("https://api.groq.com/openai")
7070
.apiKey(System.getenv("GROQ_API_KEY"))
71-
.build())
71+
.build())
7272
.defaultOptions(forModelName("llama3-8b-8192"))
7373
.build());
7474
}

spring-ai-commons/src/main/java/org/springframework/ai/template/NoOpTemplateRenderer.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.template;
1818

1919
import java.util.Map;
20+
import java.util.Set;
2021

2122
import org.springframework.util.Assert;
2223

@@ -36,4 +37,9 @@ public String apply(String template, Map<String, Object> variables) {
3637
return template;
3738
}
3839

40+
@Override
41+
public Set<String> getRequiredVariables(String template) {
42+
return Set.of();
43+
}
44+
3945
}

spring-ai-commons/src/main/java/org/springframework/ai/template/TemplateRenderer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,21 @@
1717
package org.springframework.ai.template;
1818

1919
import java.util.Map;
20+
import java.util.Set;
2021
import java.util.function.BiFunction;
2122

2223
/**
2324
* Renders a template using a given strategy.
2425
*
2526
* @author Thomas Vitale
27+
* @author Sun Yuhan
2628
* @since 1.0.0
2729
*/
2830
public interface TemplateRenderer extends BiFunction<String, Map<String, Object>, String> {
2931

3032
@Override
3133
String apply(String template, Map<String, Object> variables);
3234

35+
Set<String> getRequiredVariables(String template);
36+
3337
}

spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.HashMap;
2020
import java.util.Map;
21+
import java.util.Set;
2122

2223
import org.junit.jupiter.api.Test;
2324

@@ -316,6 +317,11 @@ public String apply(String template, Map<String, Object> model) {
316317
return template + " (Rendered by Custom)";
317318
}
318319

320+
@Override
321+
public Set<String> getRequiredVariables(String template) {
322+
return Set.of();
323+
}
324+
319325
}
320326

321327
}

spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/SystemPromptTemplateTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import java.util.HashMap;
2828
import java.util.Map;
29+
import java.util.Set;
2930

3031
import static org.assertj.core.api.Assertions.assertThat;
3132
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -324,6 +325,11 @@ public String apply(String template, Map<String, Object> model) {
324325
return template + " (Rendered by Custom)";
325326
}
326327

328+
@Override
329+
public Set<String> getRequiredVariables(String template) {
330+
return Set.of();
331+
}
332+
327333
}
328334

329335
}

spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
* is shared between threads.
5050
*
5151
* @author Thomas Vitale
52+
* @author Sun Yuhan
5253
* @since 1.0.0
5354
*/
5455
public class StTemplateRenderer implements TemplateRenderer {
@@ -110,6 +111,11 @@ public String apply(String template, Map<String, Object> variables) {
110111
return st.render();
111112
}
112113

114+
@Override
115+
public Set<String> getRequiredVariables(String template) {
116+
return getInputVariables(createST(template));
117+
}
118+
113119
private ST createST(String template) {
114120
try {
115121
return new ST(template, this.startDelimiterToken, this.endDelimiterToken);

spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.HashMap;
2020
import java.util.Map;
21+
import java.util.Set;
2122

2223
import org.junit.jupiter.api.Test;
2324

@@ -31,6 +32,7 @@
3132
* Unit tests for {@link StTemplateRenderer}.
3233
*
3334
* @author Thomas Vitale
35+
* @author Sun Yuhan
3436
*/
3537
class StTemplateRendererTests {
3638

@@ -297,4 +299,16 @@ void shouldRenderTemplateWithBuiltInFunctions() {
297299
assertThat(result).isEqualTo("Hello!");
298300
}
299301

302+
/**
303+
* Test whether the required variables can be correctly extracted from the template.
304+
*/
305+
@Test
306+
void shouldCorrectlyExtractedRequiredVariables() {
307+
StTemplateRenderer renderer = StTemplateRenderer.builder().build();
308+
String template = "Person: {name}, Age: {age}";
309+
Set<String> requiredVariables = renderer.getRequiredVariables(template);
310+
311+
assertThat(requiredVariables).contains("name", "age");
312+
}
313+
300314
}

0 commit comments

Comments
 (0)