Skip to content

Commit 8fa8224

Browse files
authored
Merge pull request #1716 from quarkiverse/#1713
Introduce support for Audio in AI services
2 parents c4ae8ff + 258e3c1 commit 8fa8224

File tree

11 files changed

+123
-3308
lines changed

11 files changed

+123
-3308
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,6 +1670,11 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
16701670
MethodParameterInfo imageUrlParam = method.parameters().get(imageParamPosition.get());
16711671
validateImageUrlParam(imageUrlParam);
16721672
}
1673+
Optional<Integer> audioParamPosition = determineAudioParamPosition(method);
1674+
if (audioParamPosition.isPresent()) {
1675+
MethodParameterInfo audioUrlParam = method.parameters().get(audioParamPosition.get());
1676+
validateAudioUrlParam(audioUrlParam);
1677+
}
16731678
Optional<Integer> pdfParamPosition = determinePdfParamPosition(method);
16741679
if (pdfParamPosition.isPresent()) {
16751680
MethodParameterInfo pdfUrlParam = method.parameters().get(pdfParamPosition.get());
@@ -1693,7 +1698,7 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
16931698
return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
16941699
AiServiceMethodCreateInfo.TemplateInfo.fromText(userMessageTemplate,
16951700
TemplateParameterInfo.toNameToArgsPositionMap(templateParams)),
1696-
userNameParamPosition, imageParamPosition, pdfParamPosition);
1701+
userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition);
16971702
} else {
16981703
Optional<AnnotationInstance> userMessageOnMethodParam = method.annotations(LangChain4jDotNames.USER_MESSAGE)
16991704
.stream()
@@ -1706,11 +1711,11 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
17061711
Short.valueOf(userMessageOnMethodParam.get().target().asMethodParameter().position())
17071712
.intValue(),
17081713
TemplateParameterInfo.toNameToArgsPositionMap(templateParams)),
1709-
userNameParamPosition, imageParamPosition, pdfParamPosition);
1714+
userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition);
17101715
} else {
17111716
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(
17121717
userMessageOnMethodParam.get().target().asMethodParameter().position(),
1713-
userNameParamPosition, imageParamPosition, pdfParamPosition);
1718+
userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition);
17141719
}
17151720
} else {
17161721
int numOfMethodParamsUsedInSystemMessage = 0;
@@ -1731,15 +1736,15 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
17311736
}
17321737
if (method.parametersCount() == 1) {
17331738
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamPosition,
1734-
imageParamPosition, pdfParamPosition);
1739+
imageParamPosition, audioParamPosition, pdfParamPosition);
17351740
}
17361741
throw illegalConfigurationForMethod(
17371742
"For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId",
17381743
method);
17391744
} else {
17401745
// all method parameters are present in the system message, so there is no user message
17411746
return new AiServiceMethodCreateInfo.UserMessageInfo(Optional.empty(), Optional.empty(), Optional.empty(),
1742-
Optional.empty(), Optional.empty());
1747+
Optional.empty(), Optional.empty(), Optional.empty());
17431748
}
17441749
}
17451750
}
@@ -1756,6 +1761,28 @@ private static Optional<Integer> determineImageParamPosition(MethodInfo method)
17561761
.map(pi -> (int) pi.position()).findFirst();
17571762
}
17581763

1764+
private static Optional<Integer> determineAudioParamPosition(MethodInfo method) {
1765+
Optional<Integer> result = method.annotations(LangChain4jDotNames.AUDIO_URL).stream().filter(
1766+
IS_METHOD_PARAMETER_ANNOTATION).map(METHOD_PARAMETER_POSITION_FUNCTION).findFirst();
1767+
if (result.isPresent()) {
1768+
return result;
1769+
}
1770+
// we don't need @AudioUrl if the parameter is of type Image
1771+
return method.parameters().stream().filter(pi -> pi.type().name().equals(LangChain4jDotNames.AUDIO))
1772+
.map(pi -> (int) pi.position()).findFirst();
1773+
}
1774+
1775+
private static Optional<Integer> determinePdfParamPosition(MethodInfo method) {
1776+
Optional<Integer> result = method.annotations(LangChain4jDotNames.PDF_URL).stream().filter(
1777+
IS_METHOD_PARAMETER_ANNOTATION).map(METHOD_PARAMETER_POSITION_FUNCTION).findFirst();
1778+
if (result.isPresent()) {
1779+
return result;
1780+
}
1781+
// we don't need @PdfUrl if the parameter is of type PdfFile
1782+
return method.parameters().stream().filter(pi -> pi.type().name().equals(LangChain4jDotNames.PDF_FILE))
1783+
.map(pi -> (int) pi.position()).findFirst();
1784+
}
1785+
17591786
private void validateImageUrlParam(MethodParameterInfo param) {
17601787
if (param == null) {
17611788
throw new IllegalArgumentException("Unhandled @ImageUrl annotation");
@@ -1769,15 +1796,17 @@ private void validateImageUrlParam(MethodParameterInfo param) {
17691796
throw new IllegalArgumentException("Unhandled @ImageUrl type '" + type.name() + "'");
17701797
}
17711798

1772-
private static Optional<Integer> determinePdfParamPosition(MethodInfo method) {
1773-
Optional<Integer> result = method.annotations(LangChain4jDotNames.PDF_URL).stream().filter(
1774-
IS_METHOD_PARAMETER_ANNOTATION).map(METHOD_PARAMETER_POSITION_FUNCTION).findFirst();
1775-
if (result.isPresent()) {
1776-
return result;
1799+
private void validateAudioUrlParam(MethodParameterInfo param) {
1800+
if (param == null) {
1801+
throw new IllegalArgumentException("Unhandled @ImageUrl annotation");
17771802
}
1778-
// we don't need @PdfUrl if the parameter is of type PdfFile
1779-
return method.parameters().stream().filter(pi -> pi.type().name().equals(LangChain4jDotNames.PDF_FILE))
1780-
.map(pi -> (int) pi.position()).findFirst();
1803+
Type type = param.type();
1804+
DotName typeName = type.name();
1805+
if (typeName.equals(DotNames.STRING) || typeName.equals(DotNames.URI) || typeName.equals(DotNames.URL)
1806+
|| typeName.equals(LangChain4jDotNames.AUDIO)) {
1807+
return;
1808+
}
1809+
throw new IllegalArgumentException("Unhandled @AudioUrl type '" + type.name() + "'");
17811810
}
17821811

17831812
private void validatePdfUrlParam(MethodParameterInfo param) {

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import dev.langchain4j.service.tool.ToolProvider;
3030
import dev.langchain4j.web.search.WebSearchEngine;
3131
import dev.langchain4j.web.search.WebSearchTool;
32+
import io.quarkiverse.langchain4j.AudioUrl;
3233
import io.quarkiverse.langchain4j.CreatedAware;
3334
import io.quarkiverse.langchain4j.ImageUrl;
3435
import io.quarkiverse.langchain4j.ModelName;
@@ -57,6 +58,7 @@ public class LangChain4jDotNames {
5758
public static final DotName AI_MESSAGE = DotName.createSimple(AiMessage.class);
5859
static final DotName USER_NAME = DotName.createSimple(UserName.class);
5960
static final DotName IMAGE_URL = DotName.createSimple(ImageUrl.class);
61+
static final DotName AUDIO_URL = DotName.createSimple(AudioUrl.class);
6062
static final DotName PDF_URL = DotName.createSimple(PdfUrl.class);
6163
static final DotName MODERATE = DotName.createSimple(Moderate.class);
6264
static final DotName MEMORY_ID = DotName.createSimple(MemoryId.class);
@@ -113,6 +115,7 @@ public class LangChain4jDotNames {
113115
static final DotName WEB_SEARCH_TOOL = DotName.createSimple(WebSearchTool.class);
114116
static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class);
115117
static final DotName IMAGE = DotName.createSimple(Image.class);
118+
static final DotName AUDIO = DotName.createSimple(dev.langchain4j.data.audio.Audio.class);
116119
static final DotName PDF_FILE = DotName.createSimple(PdfFile.class);
117120
static final DotName RESULT = DotName.createSimple(Result.class);
118121
static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProvider.class);
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.quarkiverse.langchain4j;
2+
3+
import static java.lang.annotation.ElementType.PARAMETER;
4+
import static java.lang.annotation.RetentionPolicy.RUNTIME;
5+
6+
import java.lang.annotation.Retention;
7+
import java.lang.annotation.Target;
8+
9+
/**
10+
* This annotation is useful when an AiService is meant to describe an image as the value of the method parameter annotated
11+
* with @ImageUrl
12+
* will be used as an {@link dev.langchain4j.data.message.AudioContent}.
13+
* <p>
14+
* <p>
15+
* The following code contains an example of how this can be used:
16+
*
17+
* <pre>
18+
* {@code
19+
* @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
20+
* public interface AudioDescriber {
21+
*
22+
* @UserMessage("Describe the audio")
23+
* Report describe(@AudioUrl String url);
24+
* }
25+
* </pre>
26+
*
27+
* There can be at most one instance of {@code AudioUrl} per method and the supported types are the following:
28+
* <ul>
29+
* <li>String</li>
30+
* <li>URL</li>
31+
* <li>URI</li>
32+
* <li>dev.langchain4j.data.audio.Audio</li>
33+
* </ul>
34+
*
35+
*/
36+
@Retention(RUNTIME)
37+
@Target({ PARAMETER })
38+
public @interface AudioUrl {
39+
40+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/ImageUrl.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717
* <pre>
1818
* {@code
1919
* @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
20-
* public interface PdfDescriber {
20+
* public interface ImageDescriber {
2121
*
2222
* @UserMessage("This is image was reported on a GitHub issue. If this is a snippet of Java code, please respond"
2323
* + " with only the Java code. If it is not, respond with 'NOT AN IMAGE'")
2424
* Report describe(@ImageUrl String url);
2525
* }
26-
* }
2726
* </pre>
2827
*
2928
* There can be at most one instance of {@code ImageUrl} per method and the supported types are the following:

core/runtime/src/main/java/io/quarkiverse/langchain4j/PdfUrl.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
* <pre>
1818
* {@code
1919
* @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
20-
* public interface ImageDescriber {
20+
* public interface PdfDescriber {
2121
*
2222
* @UserMessage("Analyze the following content")
2323
* String describe(@PdfUrl String url);
2424
* }
25-
* }
2625
* </pre>
2726
*
2827
* There can be at most one instance of {@code ImageUrl} per method and the supported types are the following:

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,22 @@ public record UserMessageInfo(Optional<TemplateInfo> template,
266266
Optional<Integer> paramPosition,
267267
Optional<Integer> userNameParamPosition,
268268
Optional<Integer> imageParamPosition,
269+
Optional<Integer> audioParamPosition,
269270
Optional<Integer> pdfParamPosition) {
270271

271272
public static UserMessageInfo fromMethodParam(int paramPosition, Optional<Integer> userNameParamPosition,
272-
Optional<Integer> imageParamPosition, Optional<Integer> pdfParamPosition) {
273+
Optional<Integer> imageParamPosition, Optional<Integer> audioParamPosition,
274+
Optional<Integer> pdfParamPosition) {
273275
return new UserMessageInfo(Optional.empty(), Optional.of(paramPosition),
274-
userNameParamPosition, imageParamPosition, pdfParamPosition);
276+
userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition);
275277
}
276278

277279
public static UserMessageInfo fromTemplate(TemplateInfo templateInfo, Optional<Integer> userNameParamPosition,
278280
Optional<Integer> imageUrlParamPosition,
281+
Optional<Integer> audioParamPosition,
279282
Optional<Integer> pdfParamPosition) {
280283
return new UserMessageInfo(Optional.of(templateInfo), Optional.empty(), userNameParamPosition,
281-
imageUrlParamPosition, pdfParamPosition);
284+
imageUrlParamPosition, audioParamPosition, pdfParamPosition);
282285
}
283286
}
284287

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@
3838

3939
import dev.langchain4j.agent.tool.ToolExecutionRequest;
4040
import dev.langchain4j.agent.tool.ToolSpecification;
41+
import dev.langchain4j.data.audio.Audio;
4142
import dev.langchain4j.data.image.Image;
4243
import dev.langchain4j.data.message.AiMessage;
44+
import dev.langchain4j.data.message.AudioContent;
4345
import dev.langchain4j.data.message.ChatMessage;
4446
import dev.langchain4j.data.message.ImageContent;
4547
import dev.langchain4j.data.message.PdfFileContent;
@@ -727,6 +729,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
727729

728730
String userName = null;
729731
ImageContent imageContent = null;
732+
AudioContent audioContent = null;
730733
PdfFileContent pdfFileContent = null;
731734
if (userMessageInfo.userNameParamPosition().isPresent()) {
732735
userName = methodArgs[userMessageInfo.userNameParamPosition().get()]
@@ -752,6 +755,26 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
752755
+ createInfo.getMethodName());
753756
}
754757
}
758+
if (userMessageInfo.audioParamPosition().isPresent()) {
759+
Object audioParamValue = methodArgs[userMessageInfo.audioParamPosition().get()];
760+
if (audioParamValue instanceof String s) {
761+
audioContent = AudioContent.from(s);
762+
} else if (audioParamValue instanceof URI u) {
763+
audioContent = AudioContent.from(u);
764+
} else if (audioParamValue instanceof URL u) {
765+
try {
766+
audioContent = AudioContent.from(u.toURI());
767+
} catch (URISyntaxException e) {
768+
throw new RuntimeException(e);
769+
}
770+
} else if (audioParamValue instanceof Audio a) {
771+
audioContent = AudioContent.from(a);
772+
} else {
773+
throw new IllegalStateException("Unsupported parameter type '" + audioParamValue.getClass()
774+
+ "' annotated with @AudioUrl. Offending AiService is '" + createInfo.getInterfaceName() + "#"
775+
+ createInfo.getMethodName());
776+
}
777+
}
755778
if (userMessageInfo.pdfParamPosition().isPresent()) {
756779
Object pdfParamValue = methodArgs[userMessageInfo.pdfParamPosition().get()];
757780
if (pdfParamValue instanceof String s) {
@@ -804,7 +827,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
804827
}
805828

806829
Prompt prompt = PromptTemplate.from(templateText).apply(templateVariables);
807-
return createUserMessage(userName, imageContent, pdfFileContent, prompt.text());
830+
return createUserMessage(userName, imageContent, audioContent, pdfFileContent, prompt.text());
808831

809832
} else if (userMessageInfo.paramPosition().isPresent()) {
810833
Integer paramIndex = userMessageInfo.paramPosition().get();
@@ -818,6 +841,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
818841

819842
String text = toString(argValue);
820843
return createUserMessage(userName, imageContent,
844+
audioContent,
821845
pdfFileContent, text.concat(supportsJsonSchema || !createInfo.getResponseSchemaInfo().enabled() ? ""
822846
: createInfo.getResponseSchemaInfo().outputFormatInstructions()));
823847
} else {
@@ -843,13 +867,17 @@ private static Map<String, Object> getTemplateVariables(Object[] methodArgs,
843867
return variables;
844868
}
845869

846-
private static UserMessage createUserMessage(String name, ImageContent imageContent, PdfFileContent pdfFileContent,
870+
private static UserMessage createUserMessage(String name, ImageContent imageContent, AudioContent audioContent,
871+
PdfFileContent pdfFileContent,
847872
String text) {
848873
List<dev.langchain4j.data.message.Content> contents = new ArrayList<>();
849874
contents.add(TextContent.from(text));
850875
if (imageContent != null) {
851876
contents.add(imageContent);
852877
}
878+
if (audioContent != null) {
879+
contents.add(audioContent);
880+
}
853881
if (pdfFileContent != null) {
854882
contents.add(pdfFileContent);
855883
}

0 commit comments

Comments
 (0)