import java.util.Optional;
import java.util.UUID;

+import org.jboss.logging.Logger;
+
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
-import com.github.tjake.jlama.safetensors.prompt.Tool;
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.prompt.ToolResult;
import com.github.tjake.jlama.util.JsonSupport;
import dev.langchain4j.model.output.TokenUsage;

public class JlamaChatModel implements ChatLanguageModel {
+
+    private static final Logger log = Logger.getLogger(JlamaChatModel.class);
+
    private final AbstractModel model;
    private final Float temperature;
    private final Integer maxTokens;
+    private final Boolean logRequests;
+    private final Boolean logResponses;

    public JlamaChatModel(JlamaChatModelBuilder builder) {

@@ -46,21 +52,27 @@ public JlamaChatModel(JlamaChatModelBuilder builder) {
                .withRetry(() -> registry.downloadModel(builder.modelName, Optional.ofNullable(builder.authToken)), 3);

        JlamaModel.Loader loader = jlamaModel.loader();
-        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime)
+        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime) {
            loader = loader.quantized();
+        }

-        if (builder.workingQuantizedType != null)
+        if (builder.workingQuantizedType != null) {
            loader = loader.workingQuantizationType(builder.workingQuantizedType);
+        }

-        if (builder.threadCount != null)
+        if (builder.threadCount != null) {
            loader = loader.threadCount(builder.threadCount);
+        }

-        if (builder.workingDirectory != null)
+        if (builder.workingDirectory != null) {
            loader = loader.workingDirectory(builder.workingDirectory);
+        }

        this.model = loader.load();
        this.temperature = builder.temperature == null ? 0.3f : builder.temperature;
        this.maxTokens = builder.maxTokens == null ? model.getConfig().contextLength : builder.maxTokens;
+        this.logRequests = builder.logRequests != null && builder.logRequests;
+        this.logResponses = builder.logResponses != null && builder.logResponses;
    }

    public static JlamaChatModelBuilder builder() {
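
For orientation, here is a minimal usage sketch of the loader-related options this constructor consumes. It is not taken from this PR: the `modelName`, `quantizeModelAtRuntime`, and `threadCount` setters are assumed to follow the same fluent pattern as the `modelCachePath`, `logRequests`, and `logResponses` setters visible in this diff, and the model id is a placeholder.

```java
// Hedged usage sketch, not part of this PR. Setter names other than those
// shown in the diff are assumed from the builder fields; the model id is
// illustrative only.
JlamaChatModel model = JlamaChatModel.builder()
        .modelName("tjake/Llama-3.2-1B-Instruct-Jlama-Q4") // hypothetical repo id
        .quantizeModelAtRuntime(true)                      // triggers loader.quantized()
        .threadCount(Runtime.getRuntime().availableProcessors())
        .build();
```
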
@@ -74,9 +86,29 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
-        if (model.promptSupport().isEmpty())
+        if (model.promptSupport().isEmpty()) {
            throw new UnsupportedOperationException("This model does not support chat generation");
+        }
+
+        if (logRequests) {
+            log.info("Request: " + messages);
+        }
+
+        PromptSupport.Builder promptBuilder = promptBuilder(messages);
+        Generator.Response r = model.generate(UUID.randomUUID(), promptContext(promptBuilder, toolSpecifications), temperature,
+                maxTokens, (token, time) -> {
+                });
+        Response<AiMessage> aiResponse = Response.from(aiMessageForResponse(r),
+                new TokenUsage(r.promptTokens, r.generatedTokens), toFinishReason(r.finishReason));

+        if (logResponses) {
+            log.info("Response: " + aiResponse);
+        }
+
+        return aiResponse;
+    }
+
+    private PromptSupport.Builder promptBuilder(List<ChatMessage> messages) {
        PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();

        for (ChatMessage message : messages) {
@@ -86,17 +118,18 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
                    StringBuilder finalMessage = new StringBuilder();
                    UserMessage userMessage = (UserMessage) message;
                    for (Content content : userMessage.contents()) {
-                        if (content.type() != ContentType.TEXT)
+                        if (content.type() != ContentType.TEXT) {
                            throw new UnsupportedOperationException("Unsupported content type: " + content.type());
-
+                        }
                        finalMessage.append(((TextContent) content).text());
                    }
                    promptBuilder.addUserMessage(finalMessage.toString());
                }
                case AI -> {
                    AiMessage aiMessage = (AiMessage) message;
-                    if (aiMessage.text() != null)
+                    if (aiMessage.text() != null) {
                        promptBuilder.addAssistantMessage(aiMessage.text());
+                    }

                    if (aiMessage.hasToolExecutionRequests())
                        for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
@@ -113,26 +146,26 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
                default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
            }
        }
+        return promptBuilder;
+    }

-        List<Tool> tools = toolSpecifications.stream().map(JlamaModel::toTool).toList();
-
-        PromptContext promptContext = tools.isEmpty() ? promptBuilder.build() : promptBuilder.build(tools);
-        Generator.Response r = model.generate(UUID.randomUUID(), promptContext, temperature, maxTokens, (token, time) -> {
-        });
+    private PromptContext promptContext(PromptSupport.Builder promptBuilder, List<ToolSpecification> toolSpecifications) {
+        return toolSpecifications.isEmpty() ? promptBuilder.build()
+                : promptBuilder.build(toolSpecifications.stream().map(JlamaModel::toTool).toList());
+    }

+    private AiMessage aiMessageForResponse(Generator.Response r) {
        if (r.finishReason == Generator.FinishReason.TOOL_CALL) {
            List<ToolExecutionRequest> toolCalls = r.toolCalls.stream().map(f -> ToolExecutionRequest.builder()
                    .name(f.getName())
                    .id(f.getId())
                    .arguments(JsonSupport.toJson(f.getParameters()))
                    .build()).toList();

-            return Response.from(AiMessage.from(toolCalls), new TokenUsage(r.promptTokens, r.generatedTokens),
-                    toFinishReason(r.finishReason));
+            return AiMessage.from(toolCalls);
        }

-        return Response.from(AiMessage.from(r.responseText), new TokenUsage(r.promptTokens, r.generatedTokens),
-                toFinishReason(r.finishReason));
+        return AiMessage.from(r.responseText);
    }

    @Override
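
Caller-side, the `TOOL_CALL` branch above surfaces as an `AiMessage` carrying tool execution requests rather than text. A minimal dispatch sketch using standard langchain4j accessors; the tool invocation itself is application-specific and left as a placeholder:

```java
// Sketch of distinguishing the two AiMessage shapes produced by
// aiMessageForResponse(...). Only public langchain4j accessors are used.
Response<AiMessage> response = model.generate(messages, toolSpecifications);
AiMessage aiMessage = response.content();
if (aiMessage.hasToolExecutionRequests()) {
    for (ToolExecutionRequest request : aiMessage.toolExecutionRequests()) {
        // arguments() carries the JSON built via JsonSupport.toJson(...) above
        String argumentsJson = request.arguments();
        // ... run the matching tool, then feed a ToolExecutionResultMessage back
    }
} else {
    System.out.println(aiMessage.text());
}
```
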
@@ -152,6 +185,8 @@ public static class JlamaChatModelBuilder {
        private DType workingQuantizedType;
        private Float temperature;
        private Integer maxTokens;
+        private Boolean logRequests;
+        private Boolean logResponses;

        public JlamaChatModelBuilder modelCachePath(Optional<Path> modelCachePath) {
            this.modelCachePath = modelCachePath;
@@ -198,6 +233,16 @@ public JlamaChatModelBuilder maxTokens(Integer maxTokens) {
            return this;
        }

+        public JlamaChatModelBuilder logRequests(Boolean logRequests) {
+            this.logRequests = logRequests;
+            return this;
+        }
+
+        public JlamaChatModelBuilder logResponses(Boolean logResponses) {
+            this.logResponses = logResponses;
+            return this;
+        }
+
        public JlamaChatModel build() {
            return new JlamaChatModel(this);
        }
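
With these setters, enabling the new logging is a one-liner pair on the builder; `generate(...)` then emits the `Request:`/`Response:` lines through the JBoss logger declared at the top of the class. A hedged sketch, reusing the hypothetical model id from the earlier example (both flags default to false when left unset, per the constructor's null handling):

```java
// Hedged sketch: turning on request/response logging via the new builder flags.
JlamaChatModel chatModel = JlamaChatModel.builder()
        .modelName("tjake/Llama-3.2-1B-Instruct-Jlama-Q4") // hypothetical id
        .logRequests(true)
        .logResponses(true)
        .build();
List<ChatMessage> messages = List.of(UserMessage.from("Hello!"));
Response<AiMessage> reply = chatModel.generate(messages);
```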