Skip to content

Commit eebf266

Browse files
authored
Merge pull request #1725 from ulisse1996/fix/fix-java-time-schema-generation
Handle java time types in json schema generation
2 parents f914cd4 + 7f10784 commit eebf266

File tree

5 files changed

+158
-2
lines changed

5 files changed

+158
-2
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
import java.math.BigDecimal;
44
import java.math.BigInteger;
5+
import java.time.Instant;
6+
import java.time.LocalDate;
7+
import java.time.LocalDateTime;
8+
import java.time.LocalTime;
9+
import java.time.OffsetDateTime;
10+
import java.time.OffsetTime;
11+
import java.time.Year;
12+
import java.time.YearMonth;
513
import java.util.List;
614
import java.util.Set;
715
import java.util.concurrent.CompletionStage;
@@ -56,6 +64,15 @@ public class DotNames {
5664
public static final DotName COMPLETION_STAGE = DotName.createSimple(CompletionStage.class);
5765
public static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class);
5866

67+
public static final DotName INSTANT = DotName.createSimple(Instant.class);
68+
public static final DotName LOCAL_DATE = DotName.createSimple(LocalDate.class);
69+
public static final DotName LOCAL_DATE_TIME = DotName.createSimple(LocalDateTime.class);
70+
public static final DotName LOCAL_TIME = DotName.createSimple(LocalTime.class);
71+
public static final DotName OFFSET_DATE_TIME = DotName.createSimple(OffsetDateTime.class);
72+
public static final DotName OFFSET_TIME = DotName.createSimple(OffsetTime.class);
73+
public static final DotName YEAR = DotName.createSimple(Year.class);
74+
public static final DotName YEAR_MONTH = DotName.createSimple(YearMonth.class);
75+
5976
public static final DotName OBJECT = DotName.createSimple(Object.class.getName());
6077
public static final DotName RECORD = DotName.createSimple(Record.class);
6178
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ public class ToolProcessor {
101101

102102
private static final Logger log = Logger.getLogger(ToolProcessor.class);
103103

104+
private static final List<DotName> JAVA_TIME_NAMES = List.of(
105+
DotNames.INSTANT, DotNames.LOCAL_DATE, DotNames.LOCAL_DATE_TIME, DotNames.LOCAL_TIME,
106+
DotNames.OFFSET_DATE_TIME, DotNames.OFFSET_TIME, DotNames.YEAR, DotNames.YEAR_MONTH);
107+
104108
@BuildStep
105109
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
106110
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
@@ -557,6 +561,11 @@ private JsonSchemaElement toJsonSchemaElement(Type type, IndexView index, String
557561
return JsonNumberSchema.builder().description(description).build();
558562
}
559563

564+
if (JAVA_TIME_NAMES.stream().anyMatch(typeName::equals)) {
565+
// TODO In the future we can implement parsing validation with patterns
566+
return JsonStringSchema.builder().description(description).build();
567+
}
568+
560569
// TODO something else?
561570
if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) {
562571
ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType()
@@ -587,7 +596,12 @@ private JsonSchemaElement toJsonSchemaElement(Type type, IndexView index, String
587596
.description(Optional.ofNullable(description).orElseGet(() -> descriptionFrom(type)));
588597

589598
ClassInfo targetClass = index.getClassByName(type.name());
590-
buildSchema(index, builder, targetClass);
599+
600+
if (targetClass != null) {
601+
buildSchema(index, builder, targetClass);
602+
} else {
603+
log.warnf("The type '%s' could not be accessed from the index", type.name());
604+
}
591605

592606
return builder.build();
593607
}
@@ -602,7 +616,7 @@ private void buildSchema(IndexView index, Builder builder, ClassInfo targetClass
602616
buildSchema(index, builder, superClass);
603617
}
604618
}
605-
Optional.ofNullable(targetClass)
619+
Optional.of(targetClass)
606620
.map(ClassInfo::fields)
607621
.orElseGet(List::of)
608622
.forEach(field -> {

integration-tests/tools/src/main/java/org/acme/tools/AiService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@
88
public interface AiService {
99
@ToolBox(Calculator.class)
1010
public String chat(@UserMessage String message);
11+
12+
@ToolBox(Calendar.class)
13+
String calendarChat(@UserMessage String message);
1114
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package org.acme.tools;
2+
3+
import java.time.*;
4+
import java.time.format.DateTimeFormatter;
5+
6+
import jakarta.enterprise.context.ApplicationScoped;
7+
8+
import dev.langchain4j.agent.tool.Tool;
9+
import io.smallrye.common.annotation.Blocking;
10+
11+
@ApplicationScoped
12+
public class Calendar {
13+
14+
@Tool
15+
@Blocking
16+
public String instant(Instant variable) {
17+
return DateTimeFormatter.ISO_INSTANT.format(variable);
18+
}
19+
20+
@Tool
21+
@Blocking
22+
public String date(LocalDate variable) {
23+
return variable.format(DateTimeFormatter.ISO_DATE);
24+
}
25+
26+
@Tool
27+
@Blocking
28+
public String dateTime(LocalDateTime variable) {
29+
return variable.format(DateTimeFormatter.ISO_DATE_TIME);
30+
}
31+
32+
@Tool
33+
@Blocking
34+
public String time(LocalTime variable) {
35+
return variable.format(DateTimeFormatter.ISO_LOCAL_TIME);
36+
}
37+
38+
@Tool
39+
@Blocking
40+
public String offsetDateTime(OffsetDateTime variable) {
41+
return variable.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME);
42+
}
43+
44+
@Tool
45+
@Blocking
46+
public String offsetTime(OffsetTime variable) {
47+
return variable.format(DateTimeFormatter.ISO_OFFSET_TIME);
48+
}
49+
50+
@Tool
51+
@Blocking
52+
public String year(Year variable) {
53+
return variable.toString();
54+
}
55+
56+
@Tool
57+
@Blocking
58+
public String yearMonth(YearMonth variable) {
59+
return variable.toString();
60+
}
61+
62+
@Tool
63+
@Blocking
64+
public String period(Period variable) {
65+
return variable.toString();
66+
}
67+
}

integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,25 @@
22

33
import static org.junit.jupiter.api.Assertions.assertEquals;
44

5+
import java.time.Instant;
6+
import java.time.LocalDate;
7+
import java.time.LocalDateTime;
8+
import java.time.LocalTime;
9+
import java.time.OffsetDateTime;
10+
import java.time.OffsetTime;
11+
import java.time.Period;
12+
import java.time.Year;
13+
import java.time.YearMonth;
14+
import java.util.stream.Stream;
15+
516
import jakarta.inject.Inject;
617

718
import org.junit.jupiter.api.Test;
819
import org.junit.jupiter.api.condition.EnabledForJreRange;
920
import org.junit.jupiter.api.condition.JRE;
21+
import org.junit.jupiter.params.ParameterizedTest;
22+
import org.junit.jupiter.params.provider.Arguments;
23+
import org.junit.jupiter.params.provider.MethodSource;
1024
import org.mockito.Mockito;
1125

1226
import dev.langchain4j.agent.tool.ToolExecutionRequest;
@@ -68,6 +82,47 @@ void virtualThreadSum() {
6882
assertEquals("The result is 2", aiService.chat("Execute 1 + 1"));
6983
}
7084

85+
@ParameterizedTest(name = "blocking {1}")
86+
@MethodSource("getJavaTimeChats")
87+
void blockingJavaTime(String value, String methodName) {
88+
var toolExecution = createCalendar(methodName, value);
89+
90+
Mockito.when(model.chat(Mockito.any(ChatRequest.class)))
91+
.thenReturn(
92+
ChatResponse.builder().aiMessage(AiMessage.from(toolExecution)).tokenUsage(new TokenUsage(1)).build(),
93+
ChatResponse.builder().aiMessage(AiMessage.from("Got %s".formatted(value)))
94+
.tokenUsage(new TokenUsage(1))
95+
.build());
96+
97+
assertEquals("Got %s".formatted(value), aiService.calendarChat("Execute %s".formatted(methodName)));
98+
}
99+
100+
protected static Stream<Arguments> getJavaTimeChats() {
101+
return Stream.of(
102+
Arguments.of(Instant.now().toString(), "instant"),
103+
Arguments.of(LocalDate.now().toString(), "date"),
104+
Arguments.of(LocalDateTime.now().toString(), "dateTime"),
105+
Arguments.of(LocalTime.now().toString(), "time"),
106+
Arguments.of(OffsetDateTime.now().toString(), "offsetDateTime"),
107+
Arguments.of(OffsetTime.now().toString(), "offsetTime"),
108+
Arguments.of(Year.now().toString(), "year"),
109+
Arguments.of(YearMonth.now().toString(), "yearMonth"),
110+
Arguments.of(Period.ofDays(3).toString(), "period"));
111+
}
112+
113+
private ToolExecutionRequest createCalendar(String methodName, String value) {
114+
return ToolExecutionRequest.builder()
115+
.id("1")
116+
.name(methodName)
117+
.arguments(
118+
"""
119+
{
120+
"variable": "%s"
121+
}
122+
""".formatted(value))
123+
.build();
124+
}
125+
71126
private ToolExecutionRequest create(String methodName) {
72127
return ToolExecutionRequest.builder()
73128
.id("1")

0 commit comments

Comments
 (0)