diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java index 852089e2b23..07bfefbe41b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java @@ -29,6 +29,7 @@ import org.antlr.runtime.Token; import org.antlr.runtime.TokenStream; import org.stringtemplate.v4.ST; +import org.stringtemplate.v4.compiler.Compiler; import org.stringtemplate.v4.compiler.STLexer; import org.springframework.ai.chat.messages.Message; @@ -47,6 +48,8 @@ public class PromptTemplate implements PromptTemplateActions, PromptTemplateMess private Map dynamicModel = new HashMap<>(); + private boolean skipRenderValidate = true; + public PromptTemplate(Resource resource) { try (InputStream inputStream = resource.getInputStream()) { this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); @@ -106,6 +109,26 @@ public PromptTemplate(Resource resource, Map model) { } } + public PromptTemplate(Resource resource, boolean skipRenderValidate) { + this(resource); + this.skipRenderValidate = skipRenderValidate; + } + + public PromptTemplate(String template, boolean skipRenderValidate) { + this(template); + this.skipRenderValidate = skipRenderValidate; + } + + public PromptTemplate(String template, Map model, boolean skipRenderValidate) { + this(template, model); + this.skipRenderValidate = skipRenderValidate; + } + + public PromptTemplate(Resource resource, Map model, boolean skipRenderValidate) { + this(resource, model); + this.skipRenderValidate = skipRenderValidate; + } + public void add(String name, Object value) { this.st.add(name, value); this.dynamicModel.put(name, value); @@ -199,15 +222,21 @@ public Set getInputVariables() { if (token.getType() == STLexer.LDELIM && i + 1 < tokens.size() && tokens.get(i + 1).getType() == STLexer.ID) { if (i + 2 < tokens.size() && tokens.get(i + 2).getType() == STLexer.COLON) { - inputVariables.add(tokens.get(i + 1).getText()); - isInsideList = true; + String text = tokens.get(i + 1).getText(); + if (!Compiler.funcs.containsKey(text)) { + inputVariables.add(text); + isInsideList = true; + } } } else if (token.getType() == STLexer.RDELIM) { isInsideList = false; } else if (!isInsideList && token.getType() == STLexer.ID) { - inputVariables.add(token.getText()); + if (!Compiler.funcs.containsKey(token.getText())) { + inputVariables.add(token.getText()); + } + } } @@ -222,6 +251,9 @@ private Set getModelKeys(Map model) { } protected void validate(Map model) { + if (skipRenderValidate) { + return; + } Set templateTokens = getInputVariables(); Set modelKeys = getModelKeys(model); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateTests.java new file mode 100644 index 00000000000..7eb5edec465 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import org.junit.jupiter.api.Test; + +/** + * Unit Tests for {@link PromptTemplateTests}. + * + * @author Sun Yuhan + */ +public class PromptTemplateTests { + @Test + void buildPromptTemplateWithBuiltInFunctions() { + String chatMemoryPrompt = """ + {if(strlen(memory))} + + Hello World! + + --------------------- + {memory} + --------------------- + {endif} + """; + + PromptTemplate promptTemplate = new PromptTemplate(chatMemoryPrompt); + promptTemplate.add("memory", "you are a helpful assistant"); + System.out.println(promptTemplate.render()); + } + + @Test + void buildPromptTemplateSkipRenderValidate() { + String chatMemoryPrompt = """ + {if(strlen(memory))} + + Hello World! + + --------------------- + {memory} + --------------------- + {endif} + """; + + PromptTemplate promptTemplate = new PromptTemplate(chatMemoryPrompt, true); + promptTemplate.add("memory", "you are a helpful assistant"); + System.out.println(promptTemplate.render()); + } +}