diff --git a/core/build.gradle.kts b/core/build.gradle.kts index de1530459..3168410e3 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -106,7 +106,9 @@ kotlin { } val jvmTest by getting { dependencies { - implementation(libs.kotest.junit5) + implementation(libs.ollama.testcontainers) + implementation(libs.junit.jupiter.api) + implementation(libs.junit.jupiter.engine) } } val linuxX64Main by getting { diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt index 81276970b..51cc74d16 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt @@ -6,6 +6,7 @@ import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestMo import com.xebia.functional.xef.conversation.AiDsl import com.xebia.functional.xef.conversation.Conversation import com.xebia.functional.xef.prompt.Prompt +import com.xebia.functional.xef.prompt.ToolCallStrategy import kotlin.coroutines.cancellation.CancellationException import kotlin.reflect.KClass import kotlin.reflect.KType @@ -108,10 +109,11 @@ sealed interface AI { prompt: String, target: KType = typeOf(), model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125, + toolCallStrategy: ToolCallStrategy = ToolCallStrategy.Supported, config: Config = Config(), api: Chat = OpenAI(config).chat, conversation: Conversation = Conversation() - ): A = chat(Prompt(model, prompt), target, config, api, conversation) + ): A = chat(Prompt(model, toolCallStrategy, prompt), target, config, api, conversation) @AiDsl suspend inline operator fun invoke( diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt index 4d3b1b9cd..70afeb089 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt @@ -7,12 +7,16 @@ import com.xebia.functional.xef.conversation.Conversation import com.xebia.functional.xef.llm.StreamedFunction import com.xebia.functional.xef.llm.models.modelType import com.xebia.functional.xef.llm.prompt +import com.xebia.functional.xef.llm.promptMessage import com.xebia.functional.xef.llm.promptStreaming import com.xebia.functional.xef.prompt.Prompt +import com.xebia.functional.xef.prompt.ToolCallStrategy import kotlin.reflect.KClass import kotlin.reflect.KType import kotlin.reflect.typeOf import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.mapNotNull +import kotlinx.coroutines.flow.single import kotlinx.serialization.* import kotlinx.serialization.descriptors.* @@ -29,7 +33,18 @@ data class DefaultAI( @Serializable data class Value(val value: A) private suspend fun runWithSerializer(prompt: Prompt, serializer: KSerializer): B = - api.prompt(prompt, conversation, serializer) + when (prompt.toolCallStrategy) { + ToolCallStrategy.Supported -> api.prompt(prompt, conversation, serializer) + else -> + runStreamingWithFunctionSerializer(prompt, serializer) + .mapNotNull { + when (it) { + is StreamedFunction.Property -> null + is StreamedFunction.Result -> it.value + } + } + .single() + } private fun runStreamingWithStringSerializer(prompt: Prompt): Flow = api.promptStreaming(prompt, conversation) @@ -49,6 +64,7 @@ data class DefaultAI( suspend operator fun invoke(prompt: Prompt): A { val serializer = serializer() return when (serializer.descriptor.kind) { + PrimitiveKind.STRING 
-> api.promptMessage(prompt, conversation) as A SerialKind.ENUM -> { runWithEnumSingleTokenSerializer(serializer, prompt) } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt index 60ab52f26..6a9aefafb 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt @@ -7,6 +7,7 @@ import com.xebia.functional.openai.generated.model.CreateChatCompletionResponseC import com.xebia.functional.xef.AIError import com.xebia.functional.xef.conversation.AiDsl import com.xebia.functional.xef.conversation.Conversation +import com.xebia.functional.xef.llm.PromptCalculator.adaptPromptToConversationAndModel import com.xebia.functional.xef.llm.models.MessageWithUsage import com.xebia.functional.xef.llm.models.MessagesUsage import com.xebia.functional.xef.llm.models.MessagesWithUsage @@ -18,7 +19,7 @@ import kotlinx.coroutines.flow.* @AiDsl fun Chat.promptStreaming(prompt: Prompt, scope: Conversation = Conversation()): Flow = flow { - val messagesForRequestPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope) + val messagesForRequestPrompt = prompt.adaptPromptToConversationAndModel(scope) val request = CreateChatCompletionRequest( @@ -42,14 +43,13 @@ fun Chat.promptStreaming(prompt: Prompt, scope: Conversation = Conversation()): } content } - .onEach { emit(it) } .onCompletion { val aiResponseMessage = PromptBuilder.assistant(buffer.toString()) val newMessages = prompt.messages + listOf(aiResponseMessage) newMessages.addToMemory(scope, prompt.configuration.messagePolicy.addMessagesToConversation) buffer.clear() } - .collect() + .collect { emit(it) } } @AiDsl @@ -88,7 +88,7 @@ private suspend fun Chat.promptResponse( ): Pair, CreateChatCompletionResponse> = scope.metric.promptSpan(prompt) { val promptMemories: List = prompt.messages.toMemory(scope) - val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope) + val adaptedPrompt = prompt.adaptPromptToConversationAndModel(scope) adaptedPrompt.addMetrics(scope) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt index e2d081397..43772e593 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt @@ -5,8 +5,11 @@ import arrow.core.raise.catch import com.xebia.functional.openai.generated.api.Chat import com.xebia.functional.openai.generated.model.* import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.Config import com.xebia.functional.xef.conversation.AiDsl import com.xebia.functional.xef.conversation.Conversation +import com.xebia.functional.xef.conversation.Description +import com.xebia.functional.xef.llm.PromptCalculator.adaptPromptToConversationAndModel import com.xebia.functional.xef.llm.models.functions.buildJsonSchema import com.xebia.functional.xef.prompt.Prompt import io.github.oshai.kotlinlogging.KotlinLogging @@ -20,14 +23,16 @@ import kotlinx.serialization.json.* @OptIn(ExperimentalSerializationApi::class) fun chatFunction(descriptor: SerialDescriptor): FunctionObject { val fnName = descriptor.serialName.substringAfterLast(".") - return chatFunction(fnName, buildJsonSchema(descriptor)) + val description = + descriptor.annotations.firstOrNull { it is Description }?.let { it as 
Description }?.value + return chatFunction(fnName, description, buildJsonSchema(descriptor)) } fun chatFunctions(descriptors: List): List = descriptors.map(::chatFunction) -fun chatFunction(fnName: String, schema: JsonObject): FunctionObject = - FunctionObject(fnName, "Generated function for $fnName", schema) +fun chatFunction(fnName: String, description: String?, schema: JsonObject): FunctionObject = + FunctionObject(fnName, description ?: "Generated function for $fnName", schema) @AiDsl suspend fun Chat.prompt( @@ -36,7 +41,7 @@ suspend fun Chat.prompt( serializer: KSerializer, ): A = prompt(prompt, scope, chatFunctions(listOf(serializer.descriptor))) { call -> - Json.decodeFromString(serializer, call.arguments) + Config.DEFAULT.json.decodeFromString(serializer, call.arguments) } @OptIn(ExperimentalSerializationApi::class) @@ -49,7 +54,8 @@ suspend fun Chat.prompt( ): A = prompt(prompt, scope, chatFunctions(descriptors)) { call -> // adds a `type` field with the call.functionName serial name equivalent to the call arguments - val jsonWithDiscriminator = Json.decodeFromString(JsonElement.serializer(), call.arguments) + val jsonWithDiscriminator = + Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments) val descriptor = descriptors.firstOrNull { it.serialName.endsWith(call.functionName) } ?: error("No descriptor found for ${call.functionName}") @@ -57,7 +63,7 @@ suspend fun Chat.prompt( JsonObject( jsonWithDiscriminator.jsonObject + ("type" to JsonPrimitive(descriptor.serialName)) ) - Json.decodeFromString(serializer, Json.encodeToString(newJson)) + Config.DEFAULT.json.decodeFromString(serializer, Config.DEFAULT.json.encodeToString(newJson)) } @AiDsl @@ -67,7 +73,7 @@ fun Chat.promptStreaming( serializer: KSerializer, ): Flow> = promptStreaming(prompt, scope, chatFunction(serializer.descriptor)) { json -> - Json.decodeFromString(serializer, json) + Config.DEFAULT.json.decodeFromString(serializer, json) } @AiDsl @@ -79,8 +85,7 @@ suspend fun Chat.prompt( ): A = scope.metric.promptSpan(prompt) { val promptWithFunctions = prompt.copy(functions = functions) - val adaptedPrompt = - PromptCalculator.adaptPromptToConversationAndModel(promptWithFunctions, scope) + val adaptedPrompt = promptWithFunctions.adaptPromptToConversationAndModel(scope) adaptedPrompt.addMetrics(scope) val request = createChatCompletionRequest(adaptedPrompt) tryDeserialize(serializer, promptWithFunctions.configuration.maxDeserializationAttempts) { @@ -139,7 +144,7 @@ fun Chat.promptStreaming( serializer: (json: String) -> A, ): Flow> = flow { val promptWithFunctions = prompt.copy(functions = listOf(function)) - val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(promptWithFunctions, scope) + val adaptedPrompt = promptWithFunctions.adaptPromptToConversationAndModel(scope) val request = createChatCompletionRequest(adaptedPrompt).copy(stream = true) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt index b50051eee..8c9224026 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt @@ -13,10 +13,10 @@ import com.xebia.functional.xef.store.Memory internal object PromptCalculator { - suspend fun adaptPromptToConversationAndModel(prompt: Prompt, scope: Conversation): Prompt = - when (prompt.configuration.messagePolicy.addMessagesFromConversation) { - 
MessagesFromHistory.ALL -> adaptPromptFromConversation(prompt, scope) - MessagesFromHistory.NONE -> prompt + suspend fun Prompt.adaptPromptToConversationAndModel(scope: Conversation): Prompt = + when (configuration.messagePolicy.addMessagesFromConversation) { + MessagesFromHistory.ALL -> adaptPromptFromConversation(this, scope) + MessagesFromHistory.NONE -> this } private suspend fun adaptPromptFromConversation(prompt: Prompt, scope: Conversation): Prompt { diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt index cd14ebef4..482baedd0 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt @@ -3,14 +3,16 @@ package com.xebia.functional.xef.llm import com.xebia.functional.openai.generated.api.Chat import com.xebia.functional.openai.generated.model.* import com.xebia.functional.xef.conversation.Conversation -import com.xebia.functional.xef.llm.StreamedFunction.Companion.PropertyType.* +import com.xebia.functional.xef.llm.streaming.FunctionCallFormat +import com.xebia.functional.xef.llm.streaming.JsonSupport +import com.xebia.functional.xef.llm.streaming.XmlSupport import com.xebia.functional.xef.prompt.Prompt import com.xebia.functional.xef.prompt.PromptBuilder +import com.xebia.functional.xef.prompt.PromptBuilder.Companion.user +import com.xebia.functional.xef.prompt.ToolCallStrategy +import io.ktor.client.request.* import kotlin.jvm.JvmSynthetic -import kotlinx.coroutines.flow.FlowCollector -import kotlinx.coroutines.flow.onCompletion -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.json.* +import kotlinx.coroutines.flow.* sealed class StreamedFunction { data class Property(val path: List, val name: String, val value: String) : @@ -18,6 +20,13 @@ sealed class StreamedFunction { data class Result(val value: A) : StreamedFunction() + fun print() { + when (this) { + is Property -> println("Property: $name = $value") + is Result -> println("Result: $value") + } + } + companion object { /** @@ -43,7 +52,7 @@ sealed class StreamedFunction { prompt: Prompt, request: CreateChatCompletionRequest, scope: Conversation, - serializer: (json: String) -> A, + serializer: (output: String) -> A, function: FunctionObject ) { val messages = mutableListOf() @@ -58,18 +67,14 @@ sealed class StreamedFunction { // we extract the expected JSON schema before the LLM replies val schema = function.parameters // we create an example from the schema from which we can expect and infer the paths - // as the LLM is sending us chunks with malformed JSON + // as the LLM is sending us chunks with malformed JSON or XML + // val format : FunctionCallFormat = JsonSupport if (schema != null) { - val example = createExampleFromSchema(schema) - chat - .createChatCompletionStream(request) - .onCompletion { - val newMessages = prompt.messages + messages - newMessages.addToMemory( - scope, - prompt.configuration.messagePolicy.addMessagesToConversation - ) - } + val format = functionCallFormat(prompt) + val stream = functionCallStream(prompt, chat, request) + val example = format.createExampleFromSchema(schema) + stream + .onCompletion { addMessagesToMemory(prompt, messages, scope) } .collect { responseChunk -> // Each chunk is emitted from the LLM and it will include a delta.parameters with // the function is streaming, the JSON received will be partial and usually 
malformed @@ -101,12 +106,13 @@ sealed class StreamedFunction { // we update the path // a change of property happens and we try to stream it streamProperty( + format, path, currentProperty, functionCall.arguments, streamedProperties ) - path = findPropertyPath(example, currentArg) ?: listOf(currentArg) + path = format.findPropertyPath(example, currentArg) ?: listOf(currentArg) } // update the current property being evaluated currentProperty = currentArg @@ -115,50 +121,67 @@ sealed class StreamedFunction { // the stream is finished and we try to stream the last property // because the previous chunk may had a partial property whose body // may had not been fully streamed - streamProperty(path, currentProperty, functionCall.arguments, streamedProperties) + streamProperty( + format, + path, + currentProperty, + functionCall.arguments, + streamedProperties + ) } } if (finishReason != null) { // we stream the result - streamResult(functionCall, messages, serializer) + streamResult(format, functionCall, messages, serializer) } } } } } + private suspend fun addMessagesToMemory( + prompt: Prompt, + messages: MutableList, + scope: Conversation + ) { + val newMessages = prompt.messages + messages + newMessages.addToMemory(scope, prompt.configuration.messagePolicy.addMessagesToConversation) + } + + private fun functionCallStream( + prompt: Prompt, + chat: Chat, + request: CreateChatCompletionRequest + ): Flow = + when (prompt.toolCallStrategy) { + ToolCallStrategy.Supported -> chat.createChatCompletionStream(request) + ToolCallStrategy.InferJsonFromStringResponse -> + chat.createChatCompletionStreamFromStringParsing(JsonSupport, request) + ToolCallStrategy.InferXmlFromStringResponse -> + chat.createChatCompletionStreamFromStringParsing(XmlSupport, request) + } + + private fun functionCallFormat(prompt: Prompt): FunctionCallFormat = + when (prompt.toolCallStrategy) { + ToolCallStrategy.Supported -> JsonSupport + ToolCallStrategy.InferJsonFromStringResponse -> JsonSupport + ToolCallStrategy.InferXmlFromStringResponse -> XmlSupport + } + private suspend fun FlowCollector>.streamResult( + format: FunctionCallFormat, functionCall: ChatCompletionMessageToolCallFunction, messages: MutableList, - serializer: (json: String) -> A + serializer: (output: String) -> A ) { - val arguments = functionCall.arguments + val arguments = format.cleanArguments(functionCall) + val jsonArguments = format.argumentsToJsonString(arguments) messages.add(PromptBuilder.assistant("Function call: $functionCall")) - val result = serializer(arguments) + val result = serializer(jsonArguments) // stream the result emit(Result(result)) } - /** - * The PropertyType enum represents the different types of properties that can be identified - * from JSON. These include STRING, NUMBER, BOOLEAN, ARRAY, OBJECT, NULL, and UNKNOWN. - * - * STRING: Represents a property with a string value. NUMBER: Represents a property with a - * numeric value. BOOLEAN: Represents a property with a boolean value. ARRAY: Represents a - * property that is an array of values. OBJECT: Represents a property that is an object with - * key-value pairs. NULL: Represents a property with a null value. UNKNOWN: Represents a - * property of unknown type. - */ - private enum class PropertyType { - STRING, - NUMBER, - BOOLEAN, - ARRAY, - OBJECT, - NULL, - UNKNOWN - } - /** * Streams a property * @@ -172,6 +195,7 @@ sealed class StreamedFunction { * @param streamedProperties The set of already streamed properties. 
*/ private suspend fun FlowCollector>.streamProperty( + format: FunctionCallFormat, path: List, prop: String?, currentArgs: String?, @@ -180,20 +204,12 @@ sealed class StreamedFunction { if (prop != null && currentArgs != null) { // stream a new property try { - val remainingText = currentArgs.replace("\n", "") - val body = remainingText.substringAfterLast("\"$prop\":").trim() - // detect the type of the property - val propertyType = propertyType(body) - // extract the body of the property or if null don't report it - val detectedBody = extractBody(propertyType, body) ?: return - // repack the body as a valid JSON string - val propertyValueAsJson = repackBodyAsJsonString(propertyType, detectedBody) - if (propertyValueAsJson != null) { - val propertyValue = Json.decodeFromString(propertyValueAsJson) + val propertyValue = format.propertyValue(prop, currentArgs) + if (propertyValue != null) { // we try to extract the text value of the property // or for cases like objects that we don't want to report on // we return null - val text = textProperty(propertyValue) + val text = format.textProperty(propertyValue) if (text != null) { val streamedProperty = Property(path, prop, text) // we only stream the property if it has not been streamed before @@ -210,85 +226,6 @@ sealed class StreamedFunction { } } - /** - * Repacks the detected body as a JSON string based on the provided property type. - * - * @param propertyType The property type to determine how the body should be repacked. - * @param detectedBody The detected body to be repacked as a JSON string. - * @return The repacked body as a JSON string. - */ - private fun repackBodyAsJsonString(propertyType: PropertyType, detectedBody: String?): String? = - when (propertyType) { - STRING -> "\"$detectedBody\"" - NUMBER -> detectedBody - BOOLEAN -> detectedBody - ARRAY -> "[$detectedBody]" - OBJECT -> "{$detectedBody}" - NULL -> "null" - else -> null - } - - /** - * Extracts the body from a given input string which may contain potentially malformed json or - * partial json chunk results. - * - * @param propertyType The type of property being extracted. - * @param body The input string to extract the body from. - * @return The extracted body string, or null if the body cannot be found. - */ - private fun extractBody(propertyType: PropertyType, body: String): String? = - when (propertyType) { - STRING -> stringBody.find(body)?.groupValues?.get(1) - NUMBER -> numberBody.find(body)?.groupValues?.get(1) - BOOLEAN -> booleanBody.find(body)?.groupValues?.get(1) - ARRAY -> arrayBody.find(body)?.groupValues?.get(1) - OBJECT -> objectBody.find(body)?.groupValues?.get(1) - NULL -> nullBody.find(body)?.groupValues?.get(1) - else -> null - } - - /** - * Determines the type of a property based on a partial chnk of it's body. - * - * @param body The body of the property. - * @return The type of the property. - */ - private fun propertyType(body: String): PropertyType = - when (body.firstOrNull()) { - '"' -> STRING - in '0'..'9' -> NUMBER - 't', - 'f' -> BOOLEAN - '[' -> ARRAY - '{' -> OBJECT - 'n' -> NULL - else -> UNKNOWN - } - - private val stringBody = """\"(.*?)\"""".toRegex() - private val numberBody = "(-?\\d+(\\.\\d+)?)".toRegex() - private val booleanBody = """(true|false)""".toRegex() - private val arrayBody = """\[(.*?)\]""".toRegex() - private val objectBody = """\{(.*?)\}""".toRegex() - private val nullBody = """null""".toRegex() - - /** - * Searches for the content of the property within a given JsonElement. 
- * - * @param element The JsonElement to search within. - * @return The text property as a String, or null if not found. - */ - private fun textProperty(element: JsonElement): String? { - return when (element) { - // we don't report on properties holding objects since we report on the properties of the - // object - is JsonObject -> null - is JsonArray -> element.map { textProperty(it) }.joinToString(", ") - is JsonPrimitive -> element.content - is JsonNull -> "null" - } - } - private fun mergeArgumentsWithDelta( functionCall: ChatCompletionMessageToolCallFunction, functionCall0: ChatCompletionMessageToolCallChunk @@ -305,61 +242,65 @@ sealed class StreamedFunction { ?.groupValues ?.lastOrNull() - private fun findPropertyPath(jsonElement: JsonElement, targetProperty: String): List? { - return findPropertyPathTailrec(listOf(jsonElement to emptyList()), targetProperty) - } - - private tailrec fun findPropertyPathTailrec( - stack: List>>, - targetProperty: String - ): List? { - if (stack.isEmpty()) return null - - val (currentElement, currentPath) = stack.first() - val remainingStack = stack.drop(1) - - return when (currentElement) { - is JsonObject -> { - if (currentElement.containsKey(targetProperty)) { - currentPath + targetProperty - } else { - val newStack = currentElement.entries.map { it.value to (currentPath + it.key) } - findPropertyPathTailrec(remainingStack + newStack, targetProperty) + fun Chat.createChatCompletionStreamFromStringParsing( + format: FunctionCallFormat, + request: CreateChatCompletionRequest + ): Flow { + val choiceName = + request.toolChoice?.let { + when (it) { + is ChatCompletionToolChoiceOption.CaseChatCompletionNamedToolChoice -> + it.value.function.name + else -> null } } - is JsonArray -> { - val newStack = currentElement.map { it to currentPath } - findPropertyPathTailrec(remainingStack + newStack, targetProperty) - } - else -> findPropertyPathTailrec(remainingStack, targetProperty) - } - } - - @OptIn(ExperimentalSerializationApi::class) - private fun createExampleFromSchema(jsonElement: JsonElement): JsonElement { - return when { - jsonElement is JsonObject && jsonElement.containsKey("type") -> { - when (jsonElement["type"]?.jsonPrimitive?.content) { - "object" -> { - val properties = jsonElement["properties"] as? 
JsonObject - val resultMap = mutableMapOf() - properties?.forEach { (key, value) -> - resultMap[key] = createExampleFromSchema(value) - } - JsonObject(resultMap) - } - "array" -> { - val items = jsonElement["items"] - val exampleItems = items?.let { createExampleFromSchema(it) } - JsonArray(listOfNotNull(exampleItems)) + val tools = request.tools.orEmpty() + val additionalMessage = + if (tools.isEmpty()) null + else + user( + """ + + ${chatCompletionsAvailableToolsInstructions(format, tools)} + + ${if (choiceName != null) "$choiceName" else ""} + """ + .trimIndent() + ) + val modifiedRequest = + request.copy( + toolChoice = null, + tools = emptyList(), + messages = listOfNotNull(additionalMessage) + request.messages, + stop = format.stopOn() + ) + return createChatCompletionStream(modifiedRequest).map { response -> + response.copy( + choices = + response.choices.map { choice -> + choice.copy( + ChatCompletionStreamResponseDelta( + toolCalls = + listOf( + ChatCompletionMessageToolCallChunk( + index = 0, + function = + ChatCompletionMessageToolCallChunkFunction( + name = choiceName, + arguments = choice.delta.content + ) + ) + ) + ) + ) } - "string" -> JsonPrimitive("default_string") - "number" -> JsonPrimitive(0) - else -> JsonPrimitive(null) - } - } - else -> JsonPrimitive(null) + ) } } + + fun chatCompletionsAvailableToolsInstructions( + format: FunctionCallFormat, + tools: List + ): String = tools.joinToString("\n") { tool -> format.chatCompletionToolInstructions(tool) } } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt index 48053d436..0dceac5bf 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt @@ -13,10 +13,7 @@ import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive -import net.mamoe.yamlkt.Yaml -import net.mamoe.yamlkt.YamlMap -import net.mamoe.yamlkt.literalContentOrNull -import net.mamoe.yamlkt.toYamlElement +import net.mamoe.yamlkt.* class Assistant( val assistantId: String, @@ -166,6 +163,11 @@ class Assistant( }, fileIds = parsed["file_ids"]?.let { (it as List<*>).map { it.toString() } } ?: emptyList(), + metadata = + // turn to Map + (parsed["metadata"] as? 
YamlMap)?.toContentMap()?.let { + it.mapKeys { (k, _) -> k.toString() }.mapValues { (_, v) -> v.toString() } + } ) return if (assistantRequest.assistantId != null) { val assistant = @@ -197,7 +199,10 @@ class Assistant( instructions = assistantRequest.instructions, tools = assistantTools(assistantRequest), fileIds = assistantRequest.fileIds, - metadata = null // assistantRequest.metadata + metadata = + assistantRequest.metadata?.let { + JsonObject(it.mapValues { (_, v) -> JsonPrimitive(v) }) + } ), toolsConfig = toolsConfig, config = config, diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt index 24315c5cb..b23e592dc 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt @@ -3,6 +3,7 @@ package com.xebia.functional.xef.llm.assistants import com.xebia.functional.openai.ServerSentEvent import com.xebia.functional.openai.generated.model.* import com.xebia.functional.xef.Config +import com.xebia.functional.xef.llm.assistants.RunDelta.Companion.RunDeltaEvent.* import kotlin.jvm.JvmInline import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -134,6 +135,44 @@ sealed interface RunDelta { @JvmInline @Serializable value class Unknown(val event: ServerSentEvent) : RunDelta companion object { + + fun toServerSentEvent(runDelta: RunDelta): ServerSentEvent? { + return when (runDelta) { + is MessageCompleted -> runDelta.serverSentEventOf(ThreadMessageCompleted) + is MessageCreated -> runDelta.serverSentEventOf(ThreadMessageCreated) + is MessageDelta -> runDelta.serverSentEventOf(ThreadMessageDelta) + is MessageInProgress -> runDelta.serverSentEventOf(ThreadMessageInProgress) + is MessageIncomplete -> runDelta.serverSentEventOf(ThreadMessageIncomplete) + is RunCancelled -> runDelta.serverSentEventOf(ThreadRunCancelled) + is RunCancelling -> runDelta.serverSentEventOf(ThreadRunCancelling) + is RunCompleted -> runDelta.serverSentEventOf(ThreadRunCompleted) + is RunCreated -> runDelta.serverSentEventOf(ThreadRunCreated) + is RunExpired -> runDelta.serverSentEventOf(ThreadRunExpired) + is RunFailed -> runDelta.serverSentEventOf(ThreadRunFailed) + is RunInProgress -> runDelta.serverSentEventOf(ThreadRunInProgress) + is RunQueued -> runDelta.serverSentEventOf(ThreadRunQueued) + is RunRequiresAction -> runDelta.serverSentEventOf(ThreadRunRequiresAction) + is RunStepCancelled -> runDelta.serverSentEventOf(ThreadRunStepCancelled) + is RunStepCompleted -> runDelta.serverSentEventOf(ThreadRunStepCompleted) + is RunStepCreated -> runDelta.serverSentEventOf(ThreadRunStepCreated) + is RunStepDelta -> runDelta.serverSentEventOf(ThreadRunStepDelta) + is RunStepExpired -> runDelta.serverSentEventOf(ThreadRunStepExpired) + is RunStepFailed -> runDelta.serverSentEventOf(ThreadRunStepFailed) + is RunStepInProgress -> runDelta.serverSentEventOf(ThreadRunStepInProgress) + is RunSubmitToolOutputs -> null + is ThreadCreated -> runDelta.serverSentEventOf(RunDeltaEvent.ThreadCreated) + is Unknown -> runDelta.event + } + } + + private inline fun Delta.serverSentEventOf( + event: RunDeltaEvent + ): ServerSentEvent = + ServerSentEvent( + event = event.value, + data = Config.DEFAULT.json.encodeToJsonElement(serializer(), this) + ) + fun fromServerSentEvent(serverEvent: ServerSentEvent): RunDelta { val data = serverEvent.data ?: error("Expected data in 
ServerSentEvent for RunDelta") val type = serverEvent.event ?: error("Expected event in ServerSentEvent for RunDelta") @@ -145,77 +184,71 @@ sealed interface RunDelta { return when (event) { RunDeltaEvent.ThreadCreated -> ThreadCreated(json.decodeFromJsonElement(ThreadObject.serializer(), data)) - RunDeltaEvent.ThreadRunCreated -> - RunCreated(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunQueued -> - RunQueued(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunInProgress -> + ThreadRunCreated -> RunCreated(json.decodeFromJsonElement(RunObject.serializer(), data)) + ThreadRunQueued -> RunQueued(json.decodeFromJsonElement(RunObject.serializer(), data)) + ThreadRunInProgress -> RunInProgress(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunRequiresAction -> + ThreadRunRequiresAction -> RunRequiresAction(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunCompleted -> - RunCompleted(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunFailed -> - RunFailed(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunCancelling -> + ThreadRunCompleted -> RunCompleted(json.decodeFromJsonElement(RunObject.serializer(), data)) + ThreadRunFailed -> RunFailed(json.decodeFromJsonElement(RunObject.serializer(), data)) + ThreadRunCancelling -> RunCancelling(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunCancelled -> - RunCancelled(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunExpired -> - RunExpired(json.decodeFromJsonElement(RunObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepCreated -> + ThreadRunCancelled -> RunCancelled(json.decodeFromJsonElement(RunObject.serializer(), data)) + ThreadRunExpired -> RunExpired(json.decodeFromJsonElement(RunObject.serializer(), data)) + ThreadRunStepCreated -> RunStepCreated(json.decodeFromJsonElement(RunStepObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepInProgress -> + ThreadRunStepInProgress -> RunStepInProgress(json.decodeFromJsonElement(RunStepObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepDelta -> + ThreadRunStepDelta -> RunStepDelta(json.decodeFromJsonElement(RunStepDeltaObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepCompleted -> + ThreadRunStepCompleted -> RunStepCompleted(json.decodeFromJsonElement(RunStepObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepFailed -> + ThreadRunStepFailed -> RunStepFailed(json.decodeFromJsonElement(RunStepObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepCancelled -> + ThreadRunStepCancelled -> RunStepCancelled(json.decodeFromJsonElement(RunStepObject.serializer(), data)) - RunDeltaEvent.ThreadRunStepExpired -> + ThreadRunStepExpired -> RunStepExpired(json.decodeFromJsonElement(RunStepObject.serializer(), data)) - RunDeltaEvent.ThreadMessageCreated -> + ThreadMessageCreated -> MessageCreated(json.decodeFromJsonElement(MessageObject.serializer(), data)) - RunDeltaEvent.ThreadMessageInProgress -> + ThreadMessageInProgress -> MessageInProgress(json.decodeFromJsonElement(MessageObject.serializer(), data)) - RunDeltaEvent.ThreadMessageDelta -> + ThreadMessageDelta -> MessageDelta(json.decodeFromJsonElement(MessageDeltaObject.serializer(), data)) - RunDeltaEvent.ThreadMessageCompleted -> + ThreadMessageCompleted -> MessageCompleted(json.decodeFromJsonElement(MessageObject.serializer(), data)) - 
RunDeltaEvent.ThreadMessageIncomplete -> + ThreadMessageIncomplete -> MessageIncomplete(json.decodeFromJsonElement(MessageObject.serializer(), data)) RunDeltaEvent.Error -> Unknown(serverEvent) null -> Unknown(serverEvent) } } - enum class RunDeltaEvent { - ThreadCreated, - ThreadRunCreated, - ThreadRunQueued, - ThreadRunInProgress, - ThreadRunRequiresAction, - ThreadRunCompleted, - ThreadRunFailed, - ThreadRunCancelling, - ThreadRunCancelled, - ThreadRunExpired, - ThreadRunStepCreated, - ThreadRunStepInProgress, - ThreadRunStepDelta, - ThreadRunStepCompleted, - ThreadRunStepFailed, - ThreadRunStepCancelled, - ThreadRunStepExpired, - ThreadMessageCreated, - ThreadMessageInProgress, - ThreadMessageDelta, - ThreadMessageCompleted, - ThreadMessageIncomplete, - Error + enum class RunDeltaEvent(val value: String) { + ThreadCreated("thread.created"), + ThreadRunCreated("thread.run.created"), + ThreadRunQueued("thread.run.queued"), + ThreadRunInProgress("thread.run.in_progress"), + ThreadRunRequiresAction("thread.run.requires_action"), + ThreadRunCompleted("thread.run.completed"), + ThreadRunFailed("thread.run.failed"), + ThreadRunCancelling("thread.run.cancelling"), + ThreadRunCancelled("thread.run.cancelled"), + ThreadRunExpired("thread.run.expired"), + ThreadRunStepCreated("thread.run.step.created"), + ThreadRunStepInProgress("thread.run.step.in_progress"), + ThreadRunStepDelta("thread.run.step.delta"), + ThreadRunStepCompleted("thread.run.step.completed"), + ThreadRunStepFailed("thread.run.step.failed"), + ThreadRunStepCancelled("thread.run.step.cancelled"), + ThreadRunStepExpired("thread.run.step.expired"), + ThreadMessageCreated("thread.message.created"), + ThreadMessageInProgress("thread.message.in_progress"), + ThreadMessageDelta("thread.message.delta"), + ThreadMessageCompleted("thread.message.completed"), + ThreadMessageIncomplete("thread.message.incomplete"), + Error("error") } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/AssistantPersistence.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/AssistantPersistence.kt new file mode 100644 index 000000000..f793daaf6 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/AssistantPersistence.kt @@ -0,0 +1,208 @@ +package com.xebia.functional.xef.llm.assistants.local + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.* +import kotlinx.serialization.json.JsonObject + +object AssistantPersistence { + + interface Assistant { + suspend fun get(assistantId: String): AssistantObject + + suspend fun create(createAssistantRequest: CreateAssistantRequest): AssistantObject + + suspend fun delete(assistantId: String): Boolean + + suspend fun list( + limit: Int?, + order: Assistants.OrderListAssistants?, + after: String?, + before: String? + ): ListAssistantsResponse + + suspend fun modify( + assistantId: String, + modifyAssistantRequest: ModifyAssistantRequest + ): AssistantObject + } + + interface AssistantFiles { + + suspend fun create( + assistantId: String, + createAssistantFileRequest: CreateAssistantFileRequest + ): AssistantFileObject + + suspend fun delete(assistantId: String, fileId: String): Boolean + + suspend fun get(assistantId: String, fileId: String): AssistantFileObject + + suspend fun list( + assistantId: String, + limit: Int?, + order: Assistants.OrderListAssistantFiles?, + after: String?, + before: String? 
+ ): ListAssistantFilesResponse + } + + interface Thread { + suspend fun get(threadId: String): ThreadObject + + suspend fun delete(threadId: String): Boolean + + suspend fun create( + assistantId: String?, + runId: String?, + createThreadRequest: CreateThreadRequest + ): ThreadObject + + suspend fun modify(threadId: String, modifyThreadRequest: ModifyThreadRequest): ThreadObject + } + + interface Message { + + suspend fun get(threadId: String, messageId: String): MessageObject + + suspend fun list( + threadId: String, + limit: Int?, + order: Assistants.OrderListMessages?, + after: String?, + before: String? + ): ListMessagesResponse + + suspend fun modify( + threadId: String, + messageId: String, + modifyMessageRequest: ModifyMessageRequest + ): MessageObject + + suspend fun createUserMessage( + threadId: String, + assistantId: String?, + runId: String?, + createMessageRequest: CreateMessageRequest + ): MessageObject { + val fileIds = emptyList() + val metadata = null + val role = MessageObject.Role.user + return createMessage( + threadId, + assistantId ?: "", + runId ?: "", + createMessageRequest.content, + fileIds, + metadata, + role + ) + } + + suspend fun createAssistantMessage( + threadId: String, + assistantId: String, + runId: String, + content: String + ): MessageObject { + val fileIds = emptyList() + val metadata = null + val role = MessageObject.Role.assistant + return createMessage(threadId, assistantId, runId, content, fileIds, metadata, role) + } + + suspend fun createMessage( + threadId: String, + assistantId: String, + runId: String, + content: String, + fileIds: List, + metadata: JsonObject?, + role: MessageObject.Role + ): MessageObject + + suspend fun updateContent(threadId: String, messageId: String, content: String): MessageObject + } + + interface MessageFile { + + suspend fun get(threadId: String, messageId: String, fileId: String): MessageFileObject + + suspend fun list( + threadId: String, + messageId: String, + limit: Int?, + order: Assistants.OrderListMessageFiles?, + after: String?, + before: String? + ): ListMessageFilesResponse + } + + interface Step { + + suspend fun updateToolsStep( + runObject: RunObject, + id: String, + stepCalls: + List + ): RunStepObject + + suspend fun create( + runObject: RunObject, + choice: GeneralAssistants.AssistantDecision, + toolCalls: List, + messageId: String? + ): RunStepObject + + suspend fun createToolsStep( + runObject: RunObject, + toolCalls: List + ): RunStepObject = + create( + runObject = runObject, + choice = GeneralAssistants.AssistantDecision.Tools, + toolCalls = toolCalls, + messageId = null + ) + + suspend fun createMessageStep(runObject: RunObject, messageId: String): RunStepObject = + create( + runObject = runObject, + choice = GeneralAssistants.AssistantDecision.Message, + toolCalls = emptyList(), + messageId = messageId + ) + + suspend fun get(threadId: String, runId: String, stepId: String): RunStepObject + + suspend fun list( + threadId: String, + runId: String, + limit: Int?, + order: Assistants.OrderListRunSteps?, + after: String?, + before: String? + ): ListRunStepsResponse + } + + interface Run { + + suspend fun updateRunToRequireToolOutputs( + runId: String, + selectedTool: GeneralAssistants.SelectedTool + ): RunObject + + suspend fun create(threadId: String, createRunRequest: CreateRunRequest): RunObject + + suspend fun list( + threadId: String, + limit: Int?, + order: Assistants.OrderListRuns?, + after: String?, + before: String? 
+ ): ListRunsResponse + + suspend fun get(runId: String): RunObject + + suspend fun modify(runId: String, modifyRunRequest: ModifyRunRequest): RunObject + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/GeneralAssistants.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/GeneralAssistants.kt new file mode 100644 index 000000000..04aeed2bd --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/GeneralAssistants.kt @@ -0,0 +1,632 @@ +package com.xebia.functional.xef.llm.assistants.local + +import com.xebia.functional.openai.ServerSentEvent +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.openai.generated.model.* +import com.xebia.functional.xef.AI +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.conversation.Conversation +import com.xebia.functional.xef.conversation.Description +import com.xebia.functional.xef.conversation.MessagePolicy +import com.xebia.functional.xef.llm.PromptCalculator.adaptPromptToConversationAndModel +import com.xebia.functional.xef.llm.assistants.RunDelta +import com.xebia.functional.xef.llm.assistants.RunDelta.MessageDeltaObject +import com.xebia.functional.xef.llm.chatFunction +import com.xebia.functional.xef.prompt.Prompt +import com.xebia.functional.xef.prompt.PromptBuilder.Companion.assistant +import com.xebia.functional.xef.prompt.PromptBuilder.Companion.system +import com.xebia.functional.xef.prompt.PromptBuilder.Companion.user +import com.xebia.functional.xef.prompt.ToolCallStrategy +import com.xebia.functional.xef.prompt.configuration.PromptConfiguration +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.assistantObjectToolsInner +import io.ktor.client.request.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.channels.ProducerScope +import kotlinx.coroutines.flow.* +import kotlinx.serialization.Required +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.* + +/** + * @param parent an optional parent Job. If a parent job is passed, cancelling the parent job will + * cancel all launched coroutines by [GeneralAssistants]. This is useful when you want to couple + * the lifecycle of Spring, Ktor, Android, or any other framework to [GeneralAssistants]. When any + * of them close, so will all launched jobs. + */ +class GeneralAssistants( + private val api: Chat, + private val assistantPersistence: AssistantPersistence.Assistant, + private val assistantFilesPersistence: AssistantPersistence.AssistantFiles, + private val threadPersistence: AssistantPersistence.Thread, + private val messagePersistence: AssistantPersistence.Message, + private val messageFilesPersistence: AssistantPersistence.MessageFile, + private val runPersistence: AssistantPersistence.Run, + private val runStepPersistence: AssistantPersistence.Step, + context: CoroutineContext = EmptyCoroutineContext, + parent: Job? 
= null +) : Assistants { + private val supervisor = SupervisorJob(parent) + private val scope = CoroutineScope(context + supervisor) + + // region Assistants + + override suspend fun getAssistant( + assistantId: String, + configure: HttpRequestBuilder.() -> Unit + ): AssistantObject = assistantPersistence.get(assistantId) + + override suspend fun createAssistant( + createAssistantRequest: CreateAssistantRequest, + configure: HttpRequestBuilder.() -> Unit + ): AssistantObject = assistantPersistence.create(createAssistantRequest) + + override suspend fun deleteAssistant( + assistantId: String, + configure: HttpRequestBuilder.() -> Unit + ): DeleteAssistantResponse { + val deleted = assistantPersistence.delete(assistantId) + return DeleteAssistantResponse( + id = assistantId, + deleted = deleted, + `object` = DeleteAssistantResponse.Object.assistant_deleted + ) + } + + override suspend fun listAssistants( + limit: Int?, + order: Assistants.OrderListAssistants?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListAssistantsResponse = assistantPersistence.list(limit, order, after, before) + + override suspend fun modifyAssistant( + assistantId: String, + modifyAssistantRequest: ModifyAssistantRequest, + configure: HttpRequestBuilder.() -> Unit + ): AssistantObject = assistantPersistence.modify(assistantId, modifyAssistantRequest) + + // endregion + + // region Assistant files + + override suspend fun createAssistantFile( + assistantId: String, + createAssistantFileRequest: CreateAssistantFileRequest, + configure: HttpRequestBuilder.() -> Unit + ): AssistantFileObject = assistantFilesPersistence.create(assistantId, createAssistantFileRequest) + + override suspend fun deleteAssistantFile( + assistantId: String, + fileId: String, + configure: HttpRequestBuilder.() -> Unit + ): DeleteAssistantFileResponse { + val deleted = assistantFilesPersistence.delete(assistantId, fileId) + return DeleteAssistantFileResponse( + id = fileId, + deleted = deleted, + `object` = DeleteAssistantFileResponse.Object.assistant_file_deleted + ) + } + + override suspend fun getAssistantFile( + assistantId: String, + fileId: String, + configure: HttpRequestBuilder.() -> Unit + ): AssistantFileObject = assistantFilesPersistence.get(assistantId, fileId) + + override suspend fun listAssistantFiles( + assistantId: String, + limit: Int?, + order: Assistants.OrderListAssistantFiles?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListAssistantFilesResponse = + assistantFilesPersistence.list(assistantId, limit, order, after, before) + + // endregion + + // region Threads + + override suspend fun getThread( + threadId: String, + configure: HttpRequestBuilder.() -> Unit + ): ThreadObject = threadPersistence.get(threadId) + + override suspend fun deleteThread( + threadId: String, + configure: HttpRequestBuilder.() -> Unit + ): DeleteThreadResponse = + DeleteThreadResponse( + id = threadId, + deleted = threadPersistence.delete(threadId), + `object` = DeleteThreadResponse.Object.thread_deleted + ) + + override suspend fun createThread( + createThreadRequest: CreateThreadRequest?, + configure: HttpRequestBuilder.() -> Unit + ): ThreadObject = + threadPersistence.create( + assistantId = null, + runId = null, + createThreadRequest = createThreadRequest ?: CreateThreadRequest() + ) + + override suspend fun modifyThread( + threadId: String, + modifyThreadRequest: ModifyThreadRequest, + configure: HttpRequestBuilder.() -> Unit + ): ThreadObject = 
threadPersistence.modify(threadId, modifyThreadRequest) + + override suspend fun createThreadAndRun( + createThreadAndRunRequest: CreateThreadAndRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + val thread = createThread(createThreadAndRunRequest.thread) + val run = + createRun( + thread.id, + CreateRunRequest( + assistantId = createThreadAndRunRequest.assistantId, + instructions = createThreadAndRunRequest.instructions, + tools = createThreadAndRunRequest.tools?.map { it.assistantObjectToolsInner() }, + metadata = createThreadAndRunRequest.metadata, + model = createThreadAndRunRequest.model, + additionalInstructions = createThreadAndRunRequest.instructions + ) + ) + return run + } + + override fun createThreadAndRunStream( + createThreadAndRunRequest: CreateThreadAndRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow { + TODO("Not yet implemented") + } + + // endregion + + // region Messages + + override suspend fun createMessage( + threadId: String, + createMessageRequest: CreateMessageRequest, + configure: HttpRequestBuilder.() -> Unit + ): MessageObject = + messagePersistence.createUserMessage( + threadId = threadId, + assistantId = null, + runId = null, + createMessageRequest = createMessageRequest + ) + + override suspend fun getMessage( + threadId: String, + messageId: String, + configure: HttpRequestBuilder.() -> Unit + ): MessageObject = messagePersistence.get(threadId, messageId) + + override suspend fun listMessages( + threadId: String, + limit: Int?, + order: Assistants.OrderListMessages?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListMessagesResponse = messagePersistence.list(threadId, limit, order, after, before) + + override suspend fun modifyMessage( + threadId: String, + messageId: String, + modifyMessageRequest: ModifyMessageRequest, + configure: HttpRequestBuilder.() -> Unit + ): MessageObject = messagePersistence.modify(threadId, messageId, modifyMessageRequest) + + // endregion + + // region Message files + + override suspend fun getMessageFile( + threadId: String, + messageId: String, + fileId: String, + configure: HttpRequestBuilder.() -> Unit + ): MessageFileObject = messageFilesPersistence.get(threadId, messageId, fileId) + + override suspend fun listMessageFiles( + threadId: String, + messageId: String, + limit: Int?, + order: Assistants.OrderListMessageFiles?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListMessageFilesResponse = + messageFilesPersistence.list(threadId, messageId, limit, order, after, before) + + // endregion + + // region Run + + @Description("The decision to run a tool or provide a message depending on th `context`") + @Serializable + data class ToolsOrMessageDecision( + @Required + @Description("Set `tool = 1, message = 0` if `context` requires a tool") + val tool: Int, + @Required + @Description("Set `tool = 0, message = 1` if `context` requires a reply message") + val message: Int, + @Required + @Description( + "A short statement describing the reason why you made the choice of tool or reply message." 
+ ) + val reason: String + ) + + enum class AssistantDecision { + Tools, + Message; + + companion object { + fun fromToolsOrMessageDecision(value: ToolsOrMessageDecision): AssistantDecision { + println("Decision: ${value.reason}") + return when { + value.tool == 1 -> Tools + value.message == 1 -> Message + else -> throw IllegalArgumentException("Invalid value: $value") + } + } + } + } + + private fun createPrompt( + runObject: RunObject, + messages: List, + functions: List = emptyList() + ): Prompt = + Prompt( + model = CreateChatCompletionRequestModel.Custom(runObject.model), + functions = functions, + configuration = + PromptConfiguration { + messagePolicy = MessagePolicy(historyPercent = 100, contextPercent = 0) + } + ) { + +system(runObject.instructions) + +createToolsMessages(runObject.tools) + +createPromptMessages(messages) + } + + private fun createToolsMessages( + tools: List + ): List = + tools.mapNotNull { tool -> + when (tool) { + is AssistantObjectToolsInner.CaseAssistantToolsCode -> + null // TODO implement with sandbox python environment + is AssistantObjectToolsInner.CaseAssistantToolsFunction -> + assistant( + """ + | + |Function: ${tool.value.function.name} + |Description: ${tool.value.function.description} + | + |${tool.value.function.parameters?.let { Json.encodeToString(JsonObject.serializer(), it) }} + | + | + """ + .trimMargin() + ) + is AssistantObjectToolsInner.CaseAssistantToolsRetrieval -> + null // TODO implement with vector store + } + } + + private fun createPromptMessages( + messages: List + ): List = + messages.flatMap { msg -> + msg.content.map { content -> + when (content) { + is MessageObjectContentInner.CaseMessageContentImageFileObject -> + when (msg.role) { + MessageObject.Role.user -> user("Image: ${content.value.imageFile.fileId}") + MessageObject.Role.assistant -> assistant("Image: ${content.value.imageFile.fileId}") + } + is MessageObjectContentInner.CaseMessageContentTextObject -> + when (msg.role) { + MessageObject.Role.user -> user(content.value.text.value) + MessageObject.Role.assistant -> assistant(content.value.text.value) + } + } + } + } + + @Serializable + data class SelectedTool( + @Description("the name of the tool to run") val name: String, + @Description( + "the arguments to pass to the tool, expressed as a json object in a single line string, arguments name extracted from tool JSON schema. 
Example: {\"argName1\": \"value1\", \"argName2\": \"value2\"}" + ) + val parameters: JsonObject + ) + + private suspend fun ProducerScope.processRun(runObject: RunObject) { + // notify that the run is in progress + send(RunDelta.RunInProgress(runObject)) + val thread = getThread(runObject.threadId) + val messages = listMessages(thread.id, limit = 1, order = Assistants.OrderListMessages.desc) + val decisionPrompt = decisionPrompt(runObject, messages) + val decision = AI(prompt = decisionPrompt, api = api) + val choice = AssistantDecision.fromToolsOrMessageDecision(decision) + val prompt: Prompt = createPrompt(runObject, messages.data) + when (choice) { + AssistantDecision.Tools -> { + val toolsRunStep = runStepPersistence.createToolsStep(runObject, emptyList()) + send(RunDelta.RunStepInProgress(toolsRunStep)) + val selectedTool = toolCalls(prompt, runObject) + val stepCalls = listOf(toolCallsToStepDetails(selectedTool)) + val updatedStep = runStepPersistence.updateToolsStep(runObject, toolsRunStep.id, stepCalls) + send(RunDelta.RunStepCompleted(updatedStep)) + val updatedRun = runPersistence.updateRunToRequireToolOutputs(runObject.id, selectedTool) + send(RunDelta.RunRequiresAction(updatedRun)) + } + AssistantDecision.Message -> { + val message = + messagePersistence.createAssistantMessage( + threadId = thread.id, + assistantId = runObject.assistantId, + runId = runObject.id, + content = "" + ) + val createMessageRunStep = runStepPersistence.createMessageStep(runObject, message.id) + send(RunDelta.RunStepCreated(createMessageRunStep)) + send(RunDelta.MessageInProgress(message)) + val content = StringBuilder() + AI>(prompt = prompt, api = api).collect { partialDelta -> + content.append(partialDelta) + send(messageDelta(message, partialDelta)) + } + val completedMessage = + messagePersistence.updateContent(thread.id, message.id, content.toString()) + send(RunDelta.MessageCompleted(completedMessage)) + send(RunDelta.RunCompleted(runObject)) + } + } + } + + private suspend fun decisionPrompt(runObject: RunObject, messages: ListMessagesResponse): Prompt = + Prompt( + functions = listOf(chatFunction(ToolsOrMessageDecision.serializer().descriptor)), + model = CreateChatCompletionRequestModel.Custom(runObject.model), + toolCallStrategy = runObject.toolCallStrategy() + ) { + +createToolsMessages(runObject.tools) + +createPromptMessages(messages.data) + +user("Please select the tool you would like to run or provide a message.") + } + .adaptPromptToConversationAndModel(Conversation()) + + private fun toolCallsToStepDetails(call: SelectedTool) = + RunStepDetailsToolCallsObjectToolCallsInner.CaseRunStepDetailsToolCallsFunctionObject( + RunStepDetailsToolCallsFunctionObject( + type = RunStepDetailsToolCallsFunctionObject.Type.function, + function = + RunStepDetailsToolCallsFunctionObjectFunction( + name = call.name, + arguments = + Config.DEFAULT.json.encodeToString(JsonObject.serializer(), call.parameters), + output = null + ) + ) + ) + + private suspend fun toolCalls( + currentConversationPrompt: Prompt, + runObject: RunObject + ): SelectedTool { + val functions = functionObjects(runObject) + return AI( + prompt = selectToolPrompt(functions, runObject, currentConversationPrompt), + api = api + ) + } + + private fun RunObject.toolCallStrategy(): ToolCallStrategy = + metadata?.get(ToolCallStrategy.Key)?.let { + when (it) { + is JsonPrimitive -> it.contentOrNull?.let { ToolCallStrategy.valueOf(it) } + else -> null + } + } ?: ToolCallStrategy.Supported + + private fun selectToolPrompt( + functions: List, + 
runObject: RunObject, + currentConversationPrompt: Prompt + ): Prompt = + Prompt( + functions = functions, + model = CreateChatCompletionRequestModel.Custom(runObject.model), + configuration = + PromptConfiguration { + maxTokens = 1000 + messagePolicy = MessagePolicy(historyPercent = 100, contextPercent = 0) + }, + toolCallStrategy = runObject.toolCallStrategy() + ) { + +currentConversationPrompt + runObject.tools.forEach { tool -> + when (tool) { + is AssistantObjectToolsInner.CaseAssistantToolsCode -> { + // TODO implement with sandbox python environment + } + is AssistantObjectToolsInner.CaseAssistantToolsFunction -> { + +assistant( + """ + |Function: ${tool.value.function.name} + |Description: ${tool.value.function.description} + | + |${tool.value.function.parameters?.let { Json.encodeToString(JsonObject.serializer(), it) }} + | + | + """ + .trimMargin() + ) + } + is AssistantObjectToolsInner.CaseAssistantToolsRetrieval -> { + // TODO implement with vector store + } + } + } + +assistant( + "Please provide the tool you would like to run to answer the user question. Respond in one line with the tool name and arguments and don't use \\n." + ) + } + + private fun functionObjects(runObject: RunObject) = + runObject.tools.mapNotNull { tool -> + when (tool) { + is AssistantObjectToolsInner.CaseAssistantToolsCode -> + null // TODO implement with sandbox python environment + is AssistantObjectToolsInner.CaseAssistantToolsFunction -> tool.value.function + is AssistantObjectToolsInner.CaseAssistantToolsRetrieval -> + null // TODO implement with vector store + } + } + + private fun messageDelta(message: MessageObject, partialDelta: String): RunDelta.MessageDelta = + RunDelta.MessageDelta( + MessageDeltaObject( + id = message.id, + `object` = "message_delta", + delta = + RunDelta.MessageDeltaObjectInner( + content = + listOf( + RunDelta.MessageDeltaObjectInnerContent( + index = 0, + type = "text", + text = RunDelta.MessageDeltaObjectInnerContentText(value = partialDelta) + ) + ) + ) + ) + ) + + override suspend fun createRun( + threadId: String, + createRunRequest: CreateRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject = runPersistence.create(threadId, createRunRequest) + + // .also { runObject -> + // // We remove the parent, such that our scope Job doesn't get overridden + // // This way we inherit the dispatcher, and context from createRun but run on our scope. 
+ // val context = currentCoroutineContext().minusKey(Job) + // scope.launch(context) { + // channelFlow { processRun(runObject) }.singleOrNull() + // } + // } + + override suspend fun cancelRun( + threadId: String, + runId: String, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override suspend fun listRuns( + threadId: String, + limit: Int?, + order: Assistants.OrderListRuns?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListRunsResponse = runPersistence.list(threadId, limit, order, after, before) + + override fun createRunStream( + threadId: String, + createRunRequest: CreateRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow = + channelFlow { + val thread = threadPersistence.get(threadId) + val run = createRun(thread.id, createRunRequest) + processRun(run) + } + .mapNotNull { RunDelta.toServerSentEvent(it) } + + override suspend fun getRun( + threadId: String, + runId: String, + configure: HttpRequestBuilder.() -> Unit + ): RunObject = runPersistence.get(runId) + + override suspend fun modifyRun( + threadId: String, + runId: String, + modifyRunRequest: ModifyRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject = runPersistence.modify(runId, modifyRunRequest) + + // endregion + + // region Run steps + + override suspend fun getRunStep( + threadId: String, + runId: String, + stepId: String, + configure: HttpRequestBuilder.() -> Unit + ): RunStepObject = runStepPersistence.get(threadId, runId, stepId) + + override suspend fun listRunSteps( + threadId: String, + runId: String, + limit: Int?, + order: Assistants.OrderListRunSteps?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListRunStepsResponse = runStepPersistence.list(threadId, runId, limit, order, after, before) + + // endregion + + // region Submit tool outputs + + override suspend fun submitToolOuputsToRun( + threadId: String, + runId: String, + submitToolOutputsRunRequest: SubmitToolOutputsRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override fun submitToolOuputsToRunStream( + threadId: String, + runId: String, + submitToolOutputsRunRequest: SubmitToolOutputsRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow { + TODO("Not yet implemented") + } + + // endregion + + // Guarantee backpressure on cancellation + // override fun close() = runBlocking { + // supervisor.cancelAndJoin() + // } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryAssistantFiles.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryAssistantFiles.kt new file mode 100644 index 000000000..b2fd07867 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryAssistantFiles.kt @@ -0,0 +1,65 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.AssistantFileObject +import com.xebia.functional.openai.generated.model.CreateAssistantFileRequest +import com.xebia.functional.openai.generated.model.ListAssistantFilesResponse +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class InMemoryAssistantFiles : AssistantPersistence.AssistantFiles { + private val assistantFiles = 
Atomic.unsafe(emptyMap()) + + override suspend fun create( + assistantId: String, + createAssistantFileRequest: CreateAssistantFileRequest + ): AssistantFileObject { + val uuid = UUID.generateUUID() + val assistantFileObject = + AssistantUtils.assistantFileObject(createAssistantFileRequest, assistantId) + assistantFiles.update { it + (uuid to assistantFileObject) } + return assistantFileObject + } + + override suspend fun get(assistantId: String, fileId: String): AssistantFileObject = + assistantFiles.get().values.firstOrNull { it.id == fileId } + ?: throw Exception("Assistant file not found for id: $fileId") + + override suspend fun delete(assistantId: String, fileId: String): Boolean = + assistantFiles + .updateAndGet { it.filter { (_, assistantFile) -> assistantFile.id != fileId } } + .isNotEmpty() + + override suspend fun list( + assistantId: String, + limit: Int?, + order: Assistants.OrderListAssistantFiles?, + after: String?, + before: String? + ): ListAssistantFilesResponse { + val allAssistantFiles = assistantFiles.get().values.toList() + val sortedAssistantFiles = + when (order) { + Assistants.OrderListAssistantFiles.asc -> allAssistantFiles.sortedBy { it.createdAt } + Assistants.OrderListAssistantFiles.desc -> + allAssistantFiles.sortedByDescending { it.createdAt } + null -> allAssistantFiles + } + val afterAssistantFile = after?.let { sortedAssistantFiles.indexOfFirst { it.id == after } } + val beforeAssistantFile = before?.let { sortedAssistantFiles.indexOfFirst { it.id == before } } + val assistantFilesToReturn = + sortedAssistantFiles + .let { afterAssistantFile?.let { afterIndex -> it.drop(afterIndex + 1) } ?: it } + .let { beforeAssistantFile?.let { beforeIndex -> it.take(beforeIndex) } ?: it } + .let { limit?.let { limit -> it.take(limit) } ?: it } + return ListAssistantFilesResponse( + `object` = "list", + data = assistantFilesToReturn, + firstId = assistantFilesToReturn.firstOrNull()?.id, + lastId = assistantFilesToReturn.lastOrNull()?.id, + hasMore = sortedAssistantFiles.size > assistantFilesToReturn.size + ) + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryAssistants.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryAssistants.kt new file mode 100644 index 000000000..da1834227 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryAssistants.kt @@ -0,0 +1,94 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.openai.generated.model.AssistantObject +import com.xebia.functional.openai.generated.model.CreateAssistantRequest +import com.xebia.functional.openai.generated.model.ListAssistantsResponse +import com.xebia.functional.openai.generated.model.ModifyAssistantRequest +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class InMemoryAssistants : AssistantPersistence.Assistant { + private val assistants = Atomic.unsafe(emptyMap()) + + override suspend fun create(createAssistantRequest: CreateAssistantRequest): AssistantObject { + val uuid = UUID.generateUUID() + val assistantObject = AssistantUtils.assistantObject(uuid, createAssistantRequest) + assistants.update { it + (uuid to assistantObject) } + return assistantObject + } + + override suspend fun get(assistantId: 
String): AssistantObject = + assistants.get().values.firstOrNull { it.id == assistantId } + ?: throw Exception("Assistant not found for id: $assistantId") + + override suspend fun delete(assistantId: String): Boolean = + assistants + .updateAndGet { it.filter { (_, assistant) -> assistant.id != assistantId } } + .isNotEmpty() + + override suspend fun list( + limit: Int?, + order: Assistants.OrderListAssistants?, + after: String?, + before: String? + ): ListAssistantsResponse { + val allAssistants = assistants.get().values.toList() + val sortedAssistants = + when (order) { + Assistants.OrderListAssistants.asc -> allAssistants.sortedBy { it.createdAt } + Assistants.OrderListAssistants.desc -> allAssistants.sortedByDescending { it.createdAt } + null -> allAssistants + } + val afterAssistant = after?.let { sortedAssistants.indexOfFirst { it.id == after } } + val beforeAssistant = before?.let { sortedAssistants.indexOfFirst { it.id == before } } + val assistantsToReturn = + sortedAssistants + .let { afterAssistant?.let { afterIndex -> it.drop(afterIndex + 1) } ?: it } + .let { beforeAssistant?.let { beforeIndex -> it.take(beforeIndex) } ?: it } + .let { limit?.let { limit -> it.take(limit) } ?: it } + return ListAssistantsResponse( + `object` = "list", + data = assistantsToReturn, + firstId = assistantsToReturn.firstOrNull()?.id, + lastId = assistantsToReturn.lastOrNull()?.id, + hasMore = sortedAssistants.size > assistantsToReturn.size + ) + } + + override suspend fun modify( + assistantId: String, + modifyAssistantRequest: ModifyAssistantRequest + ): AssistantObject { + val assistant = get(assistantId) + val modifiedAssistant = + AssistantUtils.modifiedAssistantObject(assistant, modifyAssistantRequest) + assistants.update { it + (UUID(assistant.id) to modifiedAssistant) } + return modifiedAssistant + } + + companion object { + operator fun invoke(api: Chat): Assistants { + val assistants = InMemoryAssistants() + val assistantFiles = InMemoryAssistantFiles() + val threads = InMemoryThreads() + val messages = InMemoryMessages() + val messageFiles = InMemoryMessagesFiles() + val runs = InMemoryRuns(assistants) + val runsSteps = InMemoryRunsSteps() + return GeneralAssistants( + api = api, + assistantPersistence = assistants, + assistantFilesPersistence = assistantFiles, + threadPersistence = threads, + messagePersistence = messages, + messageFilesPersistence = messageFiles, + runPersistence = runs, + runStepPersistence = runsSteps + ) + } + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryMessages.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryMessages.kt new file mode 100644 index 000000000..e623c1fd6 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryMessages.kt @@ -0,0 +1,98 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.ListMessagesResponse +import com.xebia.functional.openai.generated.model.MessageObject +import com.xebia.functional.openai.generated.model.ModifyMessageRequest +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils +import kotlinx.serialization.json.JsonObject +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class InMemoryMessages : AssistantPersistence.Message { + + private val messages = Atomic.unsafe(emptyMap()) + + override suspend fun get(threadId: 
String, messageId: String): MessageObject { + return messages.get()[UUID(messageId)] + ?: throw Exception("Message not found for id: $messageId") + } + + override suspend fun list( + threadId: String, + limit: Int?, + order: Assistants.OrderListMessages?, + after: String?, + before: String? + ): ListMessagesResponse { + val allMessages = messages.get().values.toList() + val sortedMessages = + when (order) { + Assistants.OrderListMessages.asc -> allMessages.sortedBy { it.createdAt } + Assistants.OrderListMessages.desc -> allMessages.sortedByDescending { it.createdAt } + null -> allMessages + } + val afterMessage = after?.let { sortedMessages.indexOfFirst { it.id == after } } + val beforeMessage = before?.let { sortedMessages.indexOfFirst { it.id == before } } + val messagesToReturn = + sortedMessages + .let { afterMessage?.let { afterIndex -> it.drop(afterIndex + 1) } ?: it } + .let { beforeMessage?.let { beforeIndex -> it.take(beforeIndex) } ?: it } + .let { limit?.let { limit -> it.take(limit) } ?: it } + return ListMessagesResponse( + `object` = "list", + data = messagesToReturn, + firstId = messagesToReturn.firstOrNull()?.id, + lastId = messagesToReturn.lastOrNull()?.id, + hasMore = sortedMessages.size > messagesToReturn.size + ) + } + + override suspend fun modify( + threadId: String, + messageId: String, + modifyMessageRequest: ModifyMessageRequest + ): MessageObject { + val message = get(threadId, messageId) + val modifiedMessage = message.copy(metadata = modifyMessageRequest.metadata ?: message.metadata) + messages.update { it + (UUID(messageId) to modifiedMessage) } + return modifiedMessage + } + + override suspend fun createMessage( + threadId: String, + assistantId: String, + runId: String, + content: String, + fileIds: List, + metadata: JsonObject?, + role: MessageObject.Role + ): MessageObject { + val uuid = UUID.generateUUID() + val message = + AssistantUtils.createMessageObject( + uuid = uuid, + threadId = threadId, + assistantId = assistantId, + runId = runId, + content = content, + fileIds = fileIds, + metadata = metadata, + role = role + ) + messages.update { it + (UUID(message.id) to message) } + return message + } + + override suspend fun updateContent( + threadId: String, + messageId: String, + content: String + ): MessageObject { + val message = get(threadId, messageId) + val updatedMessage = AssistantUtils.modifiedMessageObject(message, content) + messages.update { it + (UUID(messageId) to updatedMessage) } + return updatedMessage + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryMessagesFiles.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryMessagesFiles.kt new file mode 100644 index 000000000..ee4b6b285 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryMessagesFiles.kt @@ -0,0 +1,48 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.ListMessageFilesResponse +import com.xebia.functional.openai.generated.model.MessageFileObject +import kotlinx.uuid.UUID + +class InMemoryMessagesFiles : AssistantPersistence.MessageFile { + + private val messagesFiles = Atomic.unsafe(emptyMap()) + + override suspend fun get(threadId: String, messageId: String, fileId: String): MessageFileObject { + return messagesFiles.get()[UUID(fileId)] + ?: throw Exception("Message file not found for id: 
$fileId") + } + + override suspend fun list( + threadId: String, + messageId: String, + limit: Int?, + order: Assistants.OrderListMessageFiles?, + after: String?, + before: String? + ): ListMessageFilesResponse { + val allMessageFiles = messagesFiles.get().values.toList() + val sortedMessageFiles = + when (order) { + Assistants.OrderListMessageFiles.asc -> allMessageFiles.sortedBy { it.createdAt } + Assistants.OrderListMessageFiles.desc -> allMessageFiles.sortedByDescending { it.createdAt } + null -> allMessageFiles + } + val afterMessageFile = after?.let { sortedMessageFiles.indexOfFirst { it.id == after } } + val beforeMessageFile = before?.let { sortedMessageFiles.indexOfFirst { it.id == before } } + val messageFilesToReturn = + sortedMessageFiles + .let { afterMessageFile?.let { afterIndex -> it.drop(afterIndex + 1) } ?: it } + .let { beforeMessageFile?.let { beforeIndex -> it.take(beforeIndex) } ?: it } + .let { limit?.let { limit -> it.take(limit) } ?: it } + return ListMessageFilesResponse( + `object` = "list", + data = messageFilesToReturn, + firstId = messageFilesToReturn.firstOrNull()?.id, + lastId = messageFilesToReturn.lastOrNull()?.id, + hasMore = sortedMessageFiles.size > messageFilesToReturn.size + ) + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryRuns.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryRuns.kt new file mode 100644 index 000000000..ec61add8c --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryRuns.kt @@ -0,0 +1,76 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.CreateRunRequest +import com.xebia.functional.openai.generated.model.ListRunsResponse +import com.xebia.functional.openai.generated.model.ModifyRunRequest +import com.xebia.functional.openai.generated.model.RunObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class InMemoryRuns(private val assistants: AssistantPersistence.Assistant) : + AssistantPersistence.Run { + + private val runs = Atomic.unsafe(emptyMap()) + + override suspend fun updateRunToRequireToolOutputs( + runId: String, + selectedTool: GeneralAssistants.SelectedTool + ): RunObject { + val run = get(runId) + val modifiedRun = AssistantUtils.setRunToRequireToolOutouts(run, selectedTool) + runs.update { it + (UUID(runId) to modifiedRun) } + return modifiedRun + } + + override suspend fun create(threadId: String, createRunRequest: CreateRunRequest): RunObject { + val uuid = UUID.generateUUID() + val assistant = assistants.get(createRunRequest.assistantId) + val runObject = AssistantUtils.runObject(uuid, threadId, createRunRequest, assistant) + runs.update { it + (UUID(runObject.id) to runObject) } + return runObject + } + + override suspend fun list( + threadId: String, + limit: Int?, + order: Assistants.OrderListRuns?, + after: String?, + before: String? 
+ ): ListRunsResponse { + val allRuns = runs.get().values.toList() + val sortedRuns = + when (order) { + Assistants.OrderListRuns.asc -> allRuns.sortedBy { it.createdAt } + Assistants.OrderListRuns.desc -> allRuns.sortedByDescending { it.createdAt } + null -> allRuns + } + val afterRun = after?.let { sortedRuns.indexOfFirst { it.id == after } } + val beforeRun = before?.let { sortedRuns.indexOfFirst { it.id == before } } + val runsToReturn = + sortedRuns + .let { afterRun?.let { afterIndex -> it.drop(afterIndex + 1) } ?: it } + .let { beforeRun?.let { beforeIndex -> it.take(beforeIndex) } ?: it } + .let { limit?.let { limit -> it.take(limit) } ?: it } + return ListRunsResponse( + `object` = "list", + data = runsToReturn, + firstId = runsToReturn.firstOrNull()?.id, + lastId = runsToReturn.lastOrNull()?.id, + hasMore = sortedRuns.size > runsToReturn.size + ) + } + + override suspend fun get(runId: String): RunObject { + return runs.get()[UUID(runId)] ?: throw Exception("Run not found for id: $runId") + } + + override suspend fun modify(runId: String, modifyRunRequest: ModifyRunRequest): RunObject { + val run = get(runId) + val modifiedRun = run.copy(metadata = modifyRunRequest.metadata ?: run.metadata) + runs.update { it + (UUID(runId) to modifiedRun) } + return modifiedRun + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryRunsSteps.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryRunsSteps.kt new file mode 100644 index 000000000..8317b343c --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryRunsSteps.kt @@ -0,0 +1,56 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.ListRunStepsResponse +import com.xebia.functional.openai.generated.model.RunObject +import com.xebia.functional.openai.generated.model.RunStepDetailsToolCallsObjectToolCallsInner +import com.xebia.functional.openai.generated.model.RunStepObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class InMemoryRunsSteps : AssistantPersistence.Step { + + private val steps = Atomic.unsafe(emptyMap()) + + override suspend fun updateToolsStep( + runObject: RunObject, + id: String, + stepCalls: + List + ): RunStepObject { + val uuid = UUID(id) + val step = get(runObject.threadId, runObject.id, id) + val updatedStep = AssistantUtils.updatedRunStepObject(step, stepCalls) + steps.update { it + (uuid to updatedStep) } + return updatedStep + } + + override suspend fun create( + runObject: RunObject, + choice: GeneralAssistants.AssistantDecision, + toolCalls: List, + messageId: String? + ): RunStepObject { + val stepId = UUID.generateUUID() + val stepObject = AssistantUtils.runStepObject(stepId, runObject, choice, toolCalls, messageId) + steps.update { it + (stepId to stepObject) } + return stepObject + } + + override suspend fun get(threadId: String, runId: String, stepId: String): RunStepObject { + return steps.get()[UUID(stepId)] ?: throw Exception("Step not found for id: $stepId") + } + + override suspend fun list( + threadId: String, + runId: String, + limit: Int?, + order: Assistants.OrderListRunSteps?, + after: String?, + before: String? 
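Each of the in-memory stores above implements list() with the same cursor logic: sort by createdAt, drop everything up to and including the `after` id, cut at the `before` id, apply `limit`, and report hasMore. A condensed, generic sketch of that pattern (the Item type and function name are illustrative):

data class Item(val id: String, val createdAt: Int)

enum class Order { asc, desc }

// Same drop/take cursor pagination as the InMemory* list() implementations; the second element
// of the returned pair corresponds to the `hasMore` flag in the List*Response objects.
fun paginate(
  all: List<Item>,
  limit: Int? = null,
  order: Order? = null,
  after: String? = null,
  before: String? = null
): Pair<List<Item>, Boolean> {
  val sorted = when (order) {
    Order.asc -> all.sortedBy { it.createdAt }
    Order.desc -> all.sortedByDescending { it.createdAt }
    null -> all
  }
  // takeIf { it >= 0 } ignores unknown cursor ids in this sketch instead of propagating -1.
  val afterIndex = after?.let { a -> sorted.indexOfFirst { it.id == a } }?.takeIf { it >= 0 }
  val beforeIndex = before?.let { b -> sorted.indexOfFirst { it.id == b } }?.takeIf { it >= 0 }
  val page = sorted
    .let { items -> afterIndex?.let { items.drop(it + 1) } ?: items }
    .let { items -> beforeIndex?.let { items.take(it) } ?: items }
    .let { items -> limit?.let { items.take(it) } ?: items }
  return page to (sorted.size > page.size)
}

fun main() {
  val items = (1..5).map { Item(id = "id-$it", createdAt = it) }
  val (page, hasMore) = paginate(items, limit = 2, order = Order.asc, after = "id-1")
  println(page.map { it.id }) // [id-2, id-3]
  println(hasMore)            // true
}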
+ ): ListRunStepsResponse { + TODO("Not yet implemented") + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryThreads.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryThreads.kt new file mode 100644 index 000000000..a5b0683bc --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/InMemoryThreads.kt @@ -0,0 +1,46 @@ +package com.xebia.functional.xef.llm.assistants.local + +import arrow.fx.coroutines.Atomic +import com.xebia.functional.openai.generated.model.CreateThreadRequest +import com.xebia.functional.openai.generated.model.ModifyThreadRequest +import com.xebia.functional.openai.generated.model.ThreadObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class InMemoryThreads : AssistantPersistence.Thread { + + private val threads = Atomic.unsafe(emptyMap()) + + override suspend fun get(threadId: String): ThreadObject { + return threads.get()[UUID(threadId)] ?: throw Exception("Thread not found for id: $threadId") + } + + override suspend fun delete(threadId: String): Boolean { + val uuid = UUID(threadId) + return !threads.updateAndGet { it.filter { (id, _) -> id != uuid } }.containsKey(uuid) + } + + override suspend fun create( + assistantId: String?, + runId: String?, + createThreadRequest: CreateThreadRequest + ): ThreadObject { + val uuid = UUID.generateUUID() + val threadObject = AssistantUtils.threadObject(uuid, createThreadRequest) + val threadId = UUID(threadObject.id) + threads.update { it + (threadId to threadObject) } + return threadObject + } + + override suspend fun modify( + threadId: String, + modifyThreadRequest: ModifyThreadRequest + ): ThreadObject { + val thread = get(threadId) + val uuid = UUID(threadId) + val modifiedThread = thread.copy(metadata = modifyThreadRequest.metadata ?: thread.metadata) + threads.update { it + (uuid to modifiedThread) } + return modifiedThread + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/utils/AssistantUtils.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/utils/AssistantUtils.kt new file mode 100644 index 000000000..ddee9ab22 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/local/utils/AssistantUtils.kt @@ -0,0 +1,253 @@ +package com.xebia.functional.xef.server.assistants.utils + +import arrow.fx.coroutines.timeInMillis +import com.xebia.functional.openai.generated.model.* +import com.xebia.functional.xef.llm.assistants.local.GeneralAssistants +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +object AssistantUtils { + + fun threadObject(uuid: UUID, createThreadRequest: CreateThreadRequest): ThreadObject = + ThreadObject( + id = uuid.toString(), + `object` = ThreadObject.Object.thread, + createdAt = timeInMillis().toInt(), + metadata = createThreadRequest.metadata + ) + + fun runObject( + uuid: UUID, + threadId: String, + request: CreateRunRequest, + assistant: AssistantObject + ): RunObject = + RunObject( + id = uuid.toString(), + `object` = RunObject.Object.thread_run, + createdAt = timeInMillis().toInt(), + threadId = threadId, + assistantId = request.assistantId, + status = RunObject.Status.in_progress, + model = assistant.model, + instructions = request.instructions ?: assistant.instructions ?: "", + tools = 
request.tools ?: assistant.tools, + fileIds = assistant.fileIds, + metadata = request.metadata ?: assistant.metadata, + usage = null, + requiredAction = null, + lastError = null, + expiresAt = null, + startedAt = timeInMillis().toInt(), + cancelledAt = null, + failedAt = null, + completedAt = null + ) + + fun updatedRunStepObject( + runStepObject: RunStepObject, + stepCalls: + List + ): RunStepObject = + runStepObject.copy( + status = RunStepObject.Status.completed, + stepDetails = + RunStepObjectStepDetails.CaseRunStepDetailsToolCallsObject( + RunStepDetailsToolCallsObject( + type = RunStepDetailsToolCallsObject.Type.tool_calls, + toolCalls = + stepCalls.map { + RunStepDetailsToolCallsObjectToolCallsInner + .CaseRunStepDetailsToolCallsFunctionObject( + RunStepDetailsToolCallsFunctionObject( + id = it.value.id, + type = it.value.type, + function = it.value.function, + ) + ) + } + ) + ) + ) + + fun modifiedMessageObject(messageObject: MessageObject, content: String): MessageObject = + messageObject.copy( + content = + listOf( + MessageObjectContentInner.CaseMessageContentTextObject( + MessageContentTextObject( + type = MessageContentTextObject.Type.text, + text = MessageContentTextObjectText(value = content, annotations = emptyList()) + ) + ) + ) + ) + + fun createMessageObject( + uuid: UUID, + threadId: String, + role: MessageObject.Role, + content: String, + assistantId: String, + runId: String, + fileIds: List, + metadata: JsonObject? + ): MessageObject = + MessageObject( + id = uuid.toString(), + `object` = MessageObject.Object.thread_message, + createdAt = timeInMillis().toInt(), + threadId = threadId, + role = role, + content = + listOf( + MessageObjectContentInner.CaseMessageContentTextObject( + MessageContentTextObject( + type = MessageContentTextObject.Type.text, + text = MessageContentTextObjectText(value = content, annotations = emptyList()) + ) + ) + ), + assistantId = assistantId, + runId = runId, + fileIds = fileIds, + metadata = metadata + ) + + fun modifiedAssistantObject( + assistantObject: AssistantObject, + modifyAssistantRequest: ModifyAssistantRequest + ): AssistantObject = + assistantObject.copy( + name = modifyAssistantRequest.name ?: assistantObject.name, + description = modifyAssistantRequest.description ?: assistantObject.description, + model = modifyAssistantRequest.model ?: assistantObject.model, + instructions = modifyAssistantRequest.instructions ?: assistantObject.instructions, + tools = modifyAssistantRequest.tools ?: assistantObject.tools, + fileIds = modifyAssistantRequest.fileIds ?: assistantObject.fileIds, + metadata = modifyAssistantRequest.metadata ?: assistantObject.metadata + ) + + fun assistantObject(uuid: UUID, createAssistantRequest: CreateAssistantRequest): AssistantObject = + AssistantObject( + id = uuid.toString(), + `object` = AssistantObject.Object.assistant, + createdAt = timeInMillis().toInt(), + name = createAssistantRequest.name, + description = createAssistantRequest.description, + model = createAssistantRequest.model, + instructions = createAssistantRequest.instructions, + tools = createAssistantRequest.tools.orEmpty(), + fileIds = createAssistantRequest.fileIds.orEmpty(), + metadata = createAssistantRequest.metadata + ) + + fun assistantFileObject( + createAssistantFileRequest: CreateAssistantFileRequest, + assistantId: String + ): AssistantFileObject = + AssistantFileObject( + id = createAssistantFileRequest.fileId, + `object` = AssistantFileObject.Object.assistant_file, + createdAt = timeInMillis().toInt(), + assistantId = 
assistantId, + ) + + fun CreateThreadAndRunRequestToolsInner.assistantObjectToolsInner(): AssistantObjectToolsInner = + when (this) { + is CreateThreadAndRunRequestToolsInner.CaseAssistantToolsCode -> + AssistantObjectToolsInner.CaseAssistantToolsCode( + AssistantToolsCode(AssistantToolsCode.Type.code_interpreter) + ) + is CreateThreadAndRunRequestToolsInner.CaseAssistantToolsFunction -> + AssistantObjectToolsInner.CaseAssistantToolsFunction( + AssistantToolsFunction( + type = AssistantToolsFunction.Type.function, + function = value.function + ) + ) + is CreateThreadAndRunRequestToolsInner.CaseAssistantToolsRetrieval -> + AssistantObjectToolsInner.CaseAssistantToolsRetrieval( + AssistantToolsRetrieval(AssistantToolsRetrieval.Type.retrieval) + ) + } + + fun runStepObject( + stepId: UUID, + runObject: RunObject, + choice: GeneralAssistants.AssistantDecision, + toolCalls: List, + messageId: String? + ): RunStepObject = + RunStepObject( + id = stepId.toString(), + `object` = RunStepObject.Object.thread_run_step, + createdAt = (timeInMillis() / 1000).toInt(), + assistantId = runObject.assistantId, + threadId = runObject.threadId, + runId = runObject.id, + type = + when (choice) { + GeneralAssistants.AssistantDecision.Tools -> RunStepObject.Type.tool_calls + GeneralAssistants.AssistantDecision.Message -> RunStepObject.Type.message_creation + }, + status = RunStepObject.Status.in_progress, + stepDetails = + when (choice) { + GeneralAssistants.AssistantDecision.Tools -> + RunStepObjectStepDetails.CaseRunStepDetailsToolCallsObject( + RunStepDetailsToolCallsObject( + type = RunStepDetailsToolCallsObject.Type.tool_calls, + toolCalls = toolCalls + ) + ) + GeneralAssistants.AssistantDecision.Message -> + RunStepObjectStepDetails.CaseRunStepDetailsMessageCreationObject( + RunStepDetailsMessageCreationObject( + type = RunStepDetailsMessageCreationObject.Type.message_creation, + messageCreation = + RunStepDetailsMessageCreationObjectMessageCreation( + messageId = + messageId ?: error("Message ID is required for message creation step") + ) + ) + ) + }, + lastError = null, + expiredAt = null, + cancelledAt = null, + failedAt = null, + completedAt = null, + metadata = null, + usage = null + ) + + fun setRunToRequireToolOutouts( + runObject: RunObject, + selectedTool: GeneralAssistants.SelectedTool + ): RunObject = + runObject.copy( + requiredAction = + RunObjectRequiredAction( + type = RunObjectRequiredAction.Type.submit_tool_outputs, + submitToolOutputs = + RunObjectRequiredActionSubmitToolOutputs( + toolCalls = + listOf( + RunToolCallObject( + id = UUID.generateUUID().toString(), + type = RunToolCallObject.Type.function, + function = + RunToolCallObjectFunction( + name = selectedTool.name, + arguments = + Json.encodeToString(JsonObject.serializer(), selectedTool.parameters) + ) + ) + ) + ) + ) + ) +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/functions/JsonSchema.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/functions/JsonSchema.kt index 6fd44efd1..48fde7187 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/functions/JsonSchema.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/functions/JsonSchema.kt @@ -22,7 +22,7 @@ import kotlinx.serialization.descriptors.* import kotlinx.serialization.json.* /** Represents the type of json type */ -enum class JsonType(jsonType: String) { +private enum class JsonType(jsonType: String) { /** Represents the json array type */ ARRAY("array"), diff --git 
a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/FunctionCallFormat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/FunctionCallFormat.kt new file mode 100644 index 000000000..95638857e --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/FunctionCallFormat.kt @@ -0,0 +1,24 @@ +package com.xebia.functional.xef.llm.streaming + +import com.xebia.functional.openai.generated.model.ChatCompletionMessageToolCallFunction +import com.xebia.functional.openai.generated.model.ChatCompletionTool +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestStop +import kotlinx.serialization.json.JsonElement + +sealed interface FunctionCallFormat { + fun createExampleFromSchema(schema: JsonElement): String + + fun findPropertyPath(element: String, targetProperty: String): List? + + fun chatCompletionToolInstructions(tool: ChatCompletionTool): String + + fun propertyValue(prop: String, currentArgs: String): JsonElement? + + fun textProperty(propertyValue: JsonElement): String? + + fun stopOn(): CreateChatCompletionRequestStop? + + fun cleanArguments(functionCall: ChatCompletionMessageToolCallFunction): String + + fun argumentsToJsonString(arguments: String): String +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/JsonSupport.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/JsonSupport.kt new file mode 100644 index 000000000..fe23b7d99 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/JsonSupport.kt @@ -0,0 +1,234 @@ +package com.xebia.functional.xef.llm.streaming + +import com.xebia.functional.openai.generated.model.ChatCompletionMessageToolCallFunction +import com.xebia.functional.openai.generated.model.ChatCompletionTool +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestStop +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.llm.streaming.JsonSupport.PropertyType.* +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.* + +object JsonSupport : FunctionCallFormat { + + override fun chatCompletionToolInstructions(tool: ChatCompletionTool): String = + when (tool.type) { + ChatCompletionTool.Type.function -> { + val schema = tool.function.parameters + val parameters = + schema?.let { Config.DEFAULT.json.encodeToString(JsonElement.serializer(), it) } + val function = tool.function + """ + + ${function.name} + ${function.description} + ${parameters} + + + - `response` is exclusively in JSON format, no other text should be present in the `response`. + - The `response` follows the JSON schema defined in the parameters. + - The `response` should be a valid JSON object. + - The `response` follows the `example` structure provided in the instructions. + + + ${toolsInvokeInstructions(tool)} + + + + + Reply in JSON now: + """ + .trimIndent() + } + else -> "" + } + + private fun toolsInvokeInstructions(tool: ChatCompletionTool): String = + tool.function.parameters?.let(::createExampleFromSchema) ?: "" + + private val stringBody = """\"(.*?)\"""".toRegex() + private val numberBody = "(-?\\d+(\\.\\d+)?)".toRegex() + private val booleanBody = """(true|false)""".toRegex() + private val arrayBody = """\[(.*?)\]""".toRegex() + private val objectBody = """\{(.*?)\}""".toRegex() + private val nullBody = """null""".toRegex() + + /** + * The PropertyType enum represents the different types of properties that can be identified from + * JSON. 
These include STRING, NUMBER, BOOLEAN, ARRAY, OBJECT, NULL, and UNKNOWN. + * + * STRING: Represents a property with a string value. NUMBER: Represents a property with a numeric + * value. BOOLEAN: Represents a property with a boolean value. ARRAY: Represents a property that + * is an array of values. OBJECT: Represents a property that is an object with key-value pairs. + * NULL: Represents a property with a null value. UNKNOWN: Represents a property of unknown type. + */ + private enum class PropertyType { + STRING, + NUMBER, + BOOLEAN, + ARRAY, + OBJECT, + NULL, + UNKNOWN + } + + /** + * Repacks the detected body as a JSON string based on the provided property type. + * + * @param propertyType The property type to determine how the body should be repacked. + * @param detectedBody The detected body to be repacked as a JSON string. + * @return The repacked body as a JSON string. + */ + private fun repackBodyAsJsonString(propertyType: PropertyType, detectedBody: String?): String? = + when (propertyType) { + STRING -> "\"$detectedBody\"" + NUMBER -> detectedBody + BOOLEAN -> detectedBody + ARRAY -> "[$detectedBody]" + OBJECT -> "{$detectedBody}" + NULL -> "null" + else -> null + } + + /** + * Extracts the body from a given input string which may contain potentially malformed json or + * partial json chunk results. + * + * @param propertyType The type of property being extracted. + * @param body The input string to extract the body from. + * @return The extracted body string, or null if the body cannot be found. + */ + private fun extractBody(propertyType: PropertyType, body: String): String? = + when (propertyType) { + STRING -> stringBody.find(body)?.groupValues?.get(1) + NUMBER -> numberBody.find(body)?.groupValues?.get(1) + BOOLEAN -> booleanBody.find(body)?.groupValues?.get(1) + ARRAY -> arrayBody.find(body)?.groupValues?.get(1) + OBJECT -> objectBody.find(body)?.groupValues?.get(1) + NULL -> nullBody.find(body)?.groupValues?.get(1) + else -> null + } + + /** + * Determines the type of property based on a partial chunk of it's body. + * + * @param body The body of the property. + * @return The type of the property. + */ + private fun propertyType(body: String): PropertyType = + when (body.firstOrNull()) { + '"' -> STRING + in '0'..'9' -> NUMBER + 't', + 'f' -> BOOLEAN + '[' -> ARRAY + '{' -> OBJECT + 'n' -> NULL + else -> UNKNOWN + } + + override fun propertyValue(prop: String, currentArgs: String): JsonElement? { + val remainingText = currentArgs.replace("\n", "") + val body = remainingText.substringAfterLast("\"$prop\":").trim() + // detect the type of the property + val propertyType = propertyType(body) + // extract the body of the property or if null don't report it + val detectedBody = extractBody(propertyType, body) ?: return null + // repack the body as a valid JSON string + val propertyValueAsJson = repackBodyAsJsonString(propertyType, detectedBody) + return propertyValueAsJson?.let { Config.DEFAULT.json.parseToJsonElement(it) } + } + + /** + * Searches for the content of the property within a given JsonElement. + * + * @param propertyValue The JsonElement to search within. + * @return The text property as a String, or null if not found. + */ + override fun textProperty(propertyValue: JsonElement): String? 
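propertyValue above recovers a property from a partial, still-streaming arguments string by sniffing the value type from its first character, matching the body with a type-specific regex, and re-packing the match as valid JSON. A condensed, runnable sketch of the same idea, handling only string and number values (function and value names are illustrative):

import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement

private val stringBody = "\"(.*?)\"".toRegex()
private val numberBody = "(-?\\d+(\\.\\d+)?)".toRegex()

// Condensed version of JsonSupport.propertyValue: given a possibly incomplete `arguments` chunk
// streamed by the model, try to pull out the value of `prop` and re-parse it as a JsonElement.
fun partialPropertyValue(prop: String, currentArgs: String): JsonElement? {
  val body = currentArgs.replace("\n", "").substringAfterLast("\"$prop\":").trim()
  val first = body.firstOrNull()
  val repacked = when {
    first == '"' -> stringBody.find(body)?.groupValues?.get(1)?.let { "\"$it\"" }
    first == '-' || first?.isDigit() == true -> numberBody.find(body)?.groupValues?.get(1)
    else -> null // booleans, arrays, objects and null are elided in this sketch
  }
  return repacked?.let { Json.parseToJsonElement(it) }
}

fun main() {
  // The chunk is cut off mid-object, but the already complete "planet" value is recoverable.
  println(partialPropertyValue("planet", """{"planet": "Mars", "moo""")) // "Mars"
  println(partialPropertyValue("count", """{"count": 42}"""))            // 42
}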
{ + return when (propertyValue) { + // we don't report on properties holding objects since we report on the properties of the + // object + is JsonObject -> null + is JsonArray -> propertyValue.map { textProperty(it) }.joinToString(", ") + is JsonPrimitive -> propertyValue.content + is JsonNull -> "null" + } + } + + override fun findPropertyPath(element: String, targetProperty: String): List? { + return findPropertyPathTailrec( + listOf(Config.DEFAULT.json.parseToJsonElement(element) to emptyList()), + targetProperty + ) + } + + private tailrec fun findPropertyPathTailrec( + stack: List>>, + targetProperty: String + ): List? { + if (stack.isEmpty()) return null + + val (currentElement, currentPath) = stack.first() + val remainingStack = stack.drop(1) + + return when (currentElement) { + is JsonObject -> { + if (currentElement.containsKey(targetProperty)) { + currentPath + targetProperty + } else { + val newStack = currentElement.entries.map { it.value to (currentPath + it.key) } + findPropertyPathTailrec(remainingStack + newStack, targetProperty) + } + } + is JsonArray -> { + val newStack = currentElement.map { it to currentPath } + findPropertyPathTailrec(remainingStack + newStack, targetProperty) + } + else -> findPropertyPathTailrec(remainingStack, targetProperty) + } + } + + @OptIn(ExperimentalSerializationApi::class) + override fun createExampleFromSchema(schema: JsonElement): String { + val json = + when { + schema is JsonObject && schema.containsKey("type") -> { + when (schema["type"]?.jsonPrimitive?.content) { + "object" -> { + val properties = schema["properties"] as? JsonObject + val resultMap = mutableMapOf() + properties?.forEach { (key, value) -> + resultMap[key] = + Config.DEFAULT.json.parseToJsonElement(createExampleFromSchema(value)) + } + JsonObject(resultMap) + } + "array" -> { + val items = schema["items"] + val exampleItems = + items?.let { Config.DEFAULT.json.parseToJsonElement(createExampleFromSchema(it)) } + JsonArray(listOfNotNull(exampleItems)) + } + "string" -> JsonPrimitive("{{string}}") + "number" -> JsonPrimitive("{{number}}") + "boolean" -> JsonPrimitive("{{boolean}}") + "null" -> JsonPrimitive(null) + else -> JsonPrimitive(null) + } + } + else -> JsonPrimitive(null) + } + + return Config.DEFAULT.json.encodeToString(JsonElement.serializer(), json) + } + + override fun stopOn(): CreateChatCompletionRequestStop? 
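createExampleFromSchema above turns the tool's JSON Schema into a placeholder instance ({{string}}, {{number}}, ...) that is embedded in the instructions so the model knows what shape to produce. A trimmed sketch of the same recursion, covering object, array, string and number only (the function name is illustrative):

import kotlinx.serialization.json.*

// Trimmed-down sibling of JsonSupport.createExampleFromSchema: walk a JSON Schema and emit a
// placeholder instance that the model is asked to imitate.
fun exampleFromSchema(schema: JsonElement): JsonElement =
  when (((schema as? JsonObject)?.get("type") as? JsonPrimitive)?.contentOrNull) {
    "object" -> buildJsonObject {
      (schema.jsonObject["properties"] as? JsonObject)?.forEach { (name, propSchema) ->
        put(name, exampleFromSchema(propSchema))
      }
    }
    "array" -> buildJsonArray {
      schema.jsonObject["items"]?.let { add(exampleFromSchema(it)) }
    }
    "string" -> JsonPrimitive("{{string}}")
    "number" -> JsonPrimitive("{{number}}")
    else -> JsonNull
  }

fun main() {
  val schema = Json.parseToJsonElement(
    """{"type":"object","properties":{"planet":{"type":"string"},"moons":{"type":"number"}}}"""
  )
  // Prints: {"planet":"{{string}}","moons":"{{number}}"}
  println(exampleFromSchema(schema))
}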
= null + + override fun cleanArguments(functionCall: ChatCompletionMessageToolCallFunction): String { + return "{" + functionCall.arguments.substringAfter("{").substringBeforeLast("}") + "}" + } + + override fun argumentsToJsonString(arguments: String): String { + return arguments + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/XmlSupport.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/XmlSupport.kt new file mode 100644 index 000000000..3dc954880 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/streaming/XmlSupport.kt @@ -0,0 +1,134 @@ +package com.xebia.functional.xef.llm.streaming + +import com.xebia.functional.openai.generated.model.ChatCompletionMessageToolCallFunction +import com.xebia.functional.openai.generated.model.ChatCompletionTool +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestStop +import com.xebia.functional.xef.Config +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.json.* + +object XmlSupport : FunctionCallFormat { + override fun chatCompletionToolInstructions(tool: ChatCompletionTool): String = + when (tool.type) { + ChatCompletionTool.Type.function -> { + val schema = tool.function.parameters + val parameters = schema?.let(::jsonSchemaParametersToXml) + val function = tool.function + """ + + ${function.name} + ${function.description} + ${parameters} + ${toolsInvokeInstructions(tool)} + + """ + .trimIndent() + } + else -> "" + } + + /** tool.function.parameters is a json schema and only the param names should be extracted */ + private fun toolsInvokeInstructions(tool: ChatCompletionTool): String = + """ + + + ${tool.function.name} + + ${tool.function.parameters?.let(::jsonSchemaParametersToXMLCallInstructions)} + + + + """ + .trimIndent() + + private fun jsonSchemaParametersToXMLCallInstructions(schema: JsonObject): String { + val parameters = schema["properties"] as JsonObject + return parameters.entries.joinToString(separator = "\n") { (key, _) -> + "<$key>{{replace-with-value}}" + } + } + + private fun jsonSchemaParametersToXml(schema: JsonObject): String { + val parameters = schema["properties"] as JsonObject + return parameters.entries.joinToString("\n") { (key, value) -> + """ + + $key + ${value.jsonObject["type"]?.jsonPrimitive?.content} + ${value.jsonObject["description"]?.jsonPrimitive?.content} + + """ + .trimIndent() + } + } + + override fun createExampleFromSchema(schema: JsonElement): String { + return schema.jsonObject["properties"]?.let { properties -> + properties.jsonObject.entries.joinToString("\n") { (key, value) -> + """ + <$key>${createExampleFromSchema(value)} + """ + .trimIndent() + } + } ?: "" + } + + override fun findPropertyPath(element: String, targetProperty: String): List? { + return emptyList() + } + + override fun propertyValue(prop: String, currentArgs: String): JsonElement? { + TODO("Not yet implemented") + } + + override fun textProperty(propertyValue: JsonElement): String? { + TODO("Not yet implemented") + } + + override fun stopOn(): CreateChatCompletionRequestStop? 
= + CreateChatCompletionRequestStop.CaseString("") + + override fun cleanArguments(functionCall: ChatCompletionMessageToolCallFunction): String { + return "" + + functionCall.arguments.substringAfter("").substringBeforeLast("" + } + + /** + * Here arguments is an XML string and we need to convert it to a JSON string + * + * arguments looks like : + * + * Planet Mars + * value + */ + override fun argumentsToJsonString(arguments: String): String { + val converted = convertXmlToJson(arguments) + return converted + } + + // Function to parse XML and convert to JSON + fun convertXmlToJson(xml: String): String { + val parsedXml = parseXml(xml) + return Config.DEFAULT.json.encodeToString(JsonObject.serializer(), parsedXml) + } + + private val xmlElementRegex = "<(\\w+)>(.*?)".toRegex() + + // Simple XML parser to convert XML string to a JSON Object + fun parseXml(xml: String): JsonObject { + val matches = xmlElementRegex.findAll(xml) + val map = mutableMapOf() + + for (match in matches) { + val key = match.groupValues[1] + val value = match.groupValues[2] + if (xmlElementRegex.containsMatchIn(value)) { + map[key] = parseXml(value) + } else { + map[key] = Config.DEFAULT.json.parseToJsonElement(value) + } + } + return JsonObject(map) + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PlatformPromptBuilder.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PlatformPromptBuilder.kt index 0a80019a0..6834d4796 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PlatformPromptBuilder.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PlatformPromptBuilder.kt @@ -8,6 +8,7 @@ import com.xebia.functional.xef.prompt.configuration.PromptConfiguration class PlatformPromptBuilder( private val model: CreateChatCompletionRequestModel, private val functions: List, + private val toolCallStrategy: ToolCallStrategy, private val configuration: PromptConfiguration ) : PromptBuilder { @@ -17,13 +18,16 @@ class PlatformPromptBuilder( elements: List ): List = elements - override fun build(): Prompt = Prompt(model, preprocess(items), functions, configuration) + override fun build(): Prompt = + Prompt(model, preprocess(items), functions, toolCallStrategy, configuration) companion object { fun create( model: CreateChatCompletionRequestModel, functions: List, - configuration: PromptConfiguration - ): PlatformPromptBuilder = PlatformPromptBuilder(model, functions, configuration) + toolCallStrategy: ToolCallStrategy, + configuration: PromptConfiguration, + ): PlatformPromptBuilder = + PlatformPromptBuilder(model, functions, toolCallStrategy, configuration) } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt index 62ad2d355..010460351 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt @@ -7,6 +7,17 @@ import com.xebia.functional.xef.prompt.configuration.PromptConfiguration import kotlin.jvm.JvmOverloads import kotlin.jvm.JvmSynthetic +enum class ToolCallStrategy { + Supported, + InferJsonFromStringResponse, + InferXmlFromStringResponse, + ; + + companion object { + const val Key = "toolCallStrategy" + } +} + /** * A Prompt is a serializable list of messages and its configuration. The messages may involve * different roles. 
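argumentsToJsonString and parseXml above convert the XML-formatted tool-call arguments produced under InferXmlFromStringResponse into the JSON the rest of the pipeline expects. A self-contained sketch of the same regex-based conversion, keeping leaf values as plain strings (function name and regex are illustrative, not lifted from the patch):

import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive

// Same idea as XmlSupport.parseXml/convertXmlToJson: a deliberately small regex parser that
// turns flat <name>value</name> tool-call arguments into a JsonObject; nested elements recurse.
private val xmlElement = "<(\\w+)>(.*?)</\\1>".toRegex(RegexOption.DOT_MATCHES_ALL)

fun xmlArgumentsToJson(xml: String): JsonObject {
  val fields = mutableMapOf<String, JsonElement>()
  for (match in xmlElement.findAll(xml)) {
    val (name, value) = match.destructured
    fields[name] =
      if (xmlElement.containsMatchIn(value)) xmlArgumentsToJson(value)
      else JsonPrimitive(value.trim()) // leaf values stay strings in this sketch
  }
  return JsonObject(fields)
}

fun main() {
  val arguments = "<name>Planet Mars</name><mass>6.39e23</mass>"
  // Prints: {"name":"Planet Mars","mass":"6.39e23"}
  println(Json.encodeToString(JsonObject.serializer(), xmlArgumentsToJson(arguments)))
}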
@@ -17,28 +28,27 @@ constructor( val model: CreateChatCompletionRequestModel, val messages: List, val functions: List = emptyList(), + val toolCallStrategy: ToolCallStrategy = ToolCallStrategy.Supported, val configuration: PromptConfiguration = PromptConfiguration.DEFAULTS ) { constructor( model: CreateChatCompletionRequestModel, + toolCallStrategy: ToolCallStrategy, value: String - ) : this(model, listOf(PromptBuilder.user(value)), emptyList()) - - constructor( - model: CreateChatCompletionRequestModel, - value: String, - configuration: PromptConfiguration - ) : this(model, listOf(PromptBuilder.user(value)), emptyList(), configuration) + ) : this(model, listOf(PromptBuilder.user(value)), emptyList(), toolCallStrategy) companion object { @JvmSynthetic operator fun invoke( model: CreateChatCompletionRequestModel, functions: List = emptyList(), + toolCallStrategy: ToolCallStrategy = ToolCallStrategy.Supported, configuration: PromptConfiguration = PromptConfiguration.DEFAULTS, block: PlatformPromptBuilder.() -> Unit ): Prompt = - PlatformPromptBuilder.create(model, functions, configuration).apply { block() }.build() + PlatformPromptBuilder.create(model, functions, toolCallStrategy, configuration) + .apply { block() } + .build() } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt index ee3588efd..6b201c990 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt @@ -81,8 +81,10 @@ interface PromptBuilder { operator fun invoke( model: CreateChatCompletionRequestModel, functions: List, + toolCallStrategy: ToolCallStrategy, configuration: PromptConfiguration - ): PlatformPromptBuilder = PlatformPromptBuilder.create(model, functions, configuration) + ): PlatformPromptBuilder = + PlatformPromptBuilder.create(model, functions, toolCallStrategy, configuration) fun assistant(value: String): ChatCompletionRequestMessage = ChatCompletionRequestMessage.CaseChatCompletionRequestAssistantMessage( diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/conversation/ConversationSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/conversation/ConversationSpec.kt index 2b6d85324..847a648a4 100644 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/conversation/ConversationSpec.kt +++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/conversation/ConversationSpec.kt @@ -42,9 +42,9 @@ class ConversationSpec : val vectorStore = scope.store - chatApi.promptMessages(prompt = Prompt(model, "question 1"), scope = scope) + chatApi.promptMessages(prompt = Prompt(model, listOf(user("question 1"))), scope = scope) - chatApi.promptMessages(prompt = Prompt(model, "question 2"), scope = scope) + chatApi.promptMessages(prompt = Prompt(model, listOf(user("question 2"))), scope = scope) val memories = vectorStore.memories(model, conversationId, 10000) @@ -75,7 +75,7 @@ class ConversationSpec : model.modelType().tokensFromMessages(messages.flatMap(::chatCompletionRequestMessages)) messages.forEach { message -> - chatApi.promptMessages(prompt = Prompt(model, message.key), scope = scope) + chatApi.promptMessages(prompt = Prompt(model, listOf(user(message.key))), scope = scope) } val lastRequest = chatApi.requests.last() @@ -111,7 +111,7 @@ class ConversationSpec : model.modelType().tokensFromMessages(messages.flatMap(::chatCompletionRequestMessages)) messages.forEach { 
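With the new parameter in place, callers pick a ToolCallStrategy when building a Prompt: Supported keeps native tool calls, while the Infer* strategies make DefaultAI parse the model's plain-text response instead. A short usage sketch against the builder shown above, assuming the companion invoke and PromptBuilder helpers behave as in the diff (the model id is only an example):

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.PromptBuilder
import com.xebia.functional.xef.prompt.ToolCallStrategy

// Builds a Prompt aimed at a local model, asking DefaultAI to infer the tool call from a raw
// JSON string response instead of relying on native function calling.
val prompt: Prompt =
  Prompt(
    model = CreateChatCompletionRequestModel.Custom("llama3:8b"),
    toolCallStrategy = ToolCallStrategy.InferJsonFromStringResponse
  ) {
    +PromptBuilder.assistant("Reply with a single JSON object.")
    +PromptBuilder.user("Your favorite planet")
  }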
message -> - chatApi.promptMessages(prompt = Prompt(model, message.key), scope = scope) + chatApi.promptMessages(prompt = Prompt(model, listOf(user(message.key))), scope = scope) } val lastRequest = chatApi.requests.last() @@ -143,7 +143,7 @@ class ConversationSpec : val response: Answer = chatApi.prompt( - prompt = Prompt(model, question), + prompt = Prompt(model, listOf(user(question))), scope = scope, serializer = Answer.serializer() ) diff --git a/core/src/jvmTest/kotlin/ai/xef/tests/ollama/OllamaTests.kt b/core/src/jvmTest/kotlin/ai/xef/tests/ollama/OllamaTests.kt new file mode 100644 index 000000000..bc826e342 --- /dev/null +++ b/core/src/jvmTest/kotlin/ai/xef/tests/ollama/OllamaTests.kt @@ -0,0 +1,65 @@ +package ai.xef.tests.ollama + +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel +import com.xebia.functional.xef.AI +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.OpenAI +import com.xebia.functional.xef.prompt.ToolCallStrategy +import io.kotest.common.runBlocking +import kotlinx.serialization.Serializable +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.testcontainers.ollama.OllamaContainer +import org.testcontainers.utility.DockerImageName + +@Disabled +class OllamaTests { + + companion object { + private const val OLLAMA_IMAGE = "ollama/ollama:0.1.26" + private const val NEW_IMAGE_NAME = "ollama/ollama:test" + + val ollama: OllamaContainer by lazy { OllamaContainer(DockerImageName.parse(OLLAMA_IMAGE)) } + + @BeforeAll + @JvmStatic + fun setup() { + ollama.start() + ollama.execInContainer("ollama", "pull", "llama3:8b") + ollama.execInContainer("ollama", "run", "llama3:8b") + ollama.commitToImage(NEW_IMAGE_NAME) + } + + @AfterAll + @JvmStatic + fun teardown() { + ollama.stop() + } + } + + suspend inline fun llama3_8b( + prompt: String, + config: Config = Config(baseUrl = "http://localhost:11434/v1/"), + api: Chat = OpenAI(config = config, logRequests = true).chat, + ): A = + AI( + prompt = prompt, + config = config, + api = api, + model = CreateChatCompletionRequestModel.Custom("gemma:2b"), + toolCallStrategy = ToolCallStrategy.InferJsonFromStringResponse + ) + + @Serializable data class SolarSystemPlanet(val planet: String) + + @Test + fun `test AI chat function`() = runBlocking { + val result = llama3_8b("Your favorite planet") + println(result) + } + + // Add more tests for other functions in the AI companion object +} diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index b29a285e2..5c3788e98 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -15,6 +15,9 @@ java { } dependencies { + implementation(libs.exposed.core) + implementation(libs.flyway.core) + implementation(libs.hikari) implementation(projects.xefCore) implementation(projects.xefEvaluator) implementation(projects.xefFilesystem) @@ -24,6 +27,9 @@ dependencies { implementation(projects.xefReasoning) implementation(projects.xefOpentelemetry) implementation(projects.xefMlflow) + implementation(projects.xefServer) + implementation(projects.xefPostgresql) + implementation(projects.xefAwsBedrock) implementation(libs.suspendApp.core) implementation(libs.kotlinx.serialization.json) implementation(libs.logback) diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/LocalAssistant.kt 
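OllamaTests above boots Ollama through Testcontainers and talks to it via the OpenAI-compatible endpoint. A minimal sketch of that wiring, deriving baseUrl from the container's mapped port rather than assuming localhost:11434 (assumes the Testcontainers Ollama module exposes its default port 11434):

import org.testcontainers.ollama.OllamaContainer
import org.testcontainers.utility.DockerImageName

fun main() {
  val ollama = OllamaContainer(DockerImageName.parse("ollama/ollama:0.1.26"))
  ollama.start()
  try {
    // Pull the model inside the container, as the test setup above does.
    ollama.execInContainer("ollama", "pull", "llama3:8b")
    val baseUrl = "http://${ollama.host}:${ollama.getMappedPort(11434)}/v1/"
    println("OpenAI-compatible endpoint: $baseUrl")
    // Config(baseUrl = baseUrl) and CreateChatCompletionRequestModel.Custom("llama3:8b")
    // would then be passed to AI(...) exactly as in OllamaTests above.
  } finally {
    ollama.stop()
  }
}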
b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/LocalAssistant.kt new file mode 100644 index 000000000..2b05e1610 --- /dev/null +++ b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/LocalAssistant.kt @@ -0,0 +1,61 @@ +package com.xebia.functional.xef.assistants + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.OpenAI +import com.xebia.functional.xef.llm.assistants.Assistant +import com.xebia.functional.xef.llm.assistants.AssistantThread +import com.xebia.functional.xef.llm.assistants.RunDelta +import com.xebia.functional.xef.llm.assistants.Tool +import com.xebia.functional.xef.llm.assistants.local.InMemoryAssistants +import com.xebia.functional.xef.prompt.ToolCallStrategy + +suspend fun getAssistant(assistants: Assistants): Assistant { + // language=yaml + val yamlConfig = + """ + model: "llama3:8b" + name: "Math Assistant" + description: "Help with math" + instructions: + Roleplay: Assistant that helps with math or other general questions. + Instructions: + - For math it has a SumTool. For other questions just reply with the answer. + - If the user input does not contain information to fill the parameters of the tool, + - the assistant will ask for the missing information. + tools: + - type: "function" + name: "SumTool" + metadata: + ${ToolCallStrategy.Key}: ${ToolCallStrategy.InferJsonFromStringResponse.name} + """ + .trimIndent() + val tools = listOf(Tool.toolOf(SumTool())) + + return Assistant.fromConfig(request = yamlConfig, toolsConfig = tools, assistantsApi = assistants) +} + +suspend fun main() { + + val config = Config(baseUrl = "http://localhost:11434/v1/") + val chat = OpenAI(config = config, logRequests = true).chat + val localAssistants = InMemoryAssistants(api = chat) + + val assistant = getAssistant(localAssistants) + + val assistantInfo = assistant.get() + println("assistant: $assistantInfo") + val thread = AssistantThread(api = localAssistants) + println("Enter a message or type 'exit' to quit:") + while (true) { + val input = readlnOrNull() ?: break + if (input == "exit") break + thread.createMessage(input) + thread.run(assistant).collect { + when (it) { + is RunDelta.MessageDelta -> print(it.messageDelta.delta.content.firstOrNull()?.text?.value) + else -> it.printEvent() + } + } + } +} diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/PostgresAssistant.kt b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/PostgresAssistant.kt new file mode 100644 index 000000000..b1a4ada06 --- /dev/null +++ b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/PostgresAssistant.kt @@ -0,0 +1,42 @@ +package com.xebia.functional.xef.assistants + +import arrow.fx.coroutines.resourceScope +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.OpenAI +import com.xebia.functional.xef.llm.assistants.AssistantThread +import com.xebia.functional.xef.llm.assistants.RunDelta +import com.xebia.functional.xef.server.assistants.postgres.PostgresAssistant +import com.xebia.functional.xef.server.services.hikariDataSource +import org.jetbrains.exposed.sql.Database + +suspend fun main() { + + resourceScope { + val config = Config(baseUrl = "http://localhost:11434/v1/") + val chat = OpenAI(config = config, logRequests = true).chat + + val xefDatasource = + hikariDataSource("jdbc:postgresql://localhost:5433/xef_database", "postgres", "postgres") + + Database.connect(xefDatasource) + + val postgresAssistant = PostgresAssistant(api = 
chat) + val assistant = getAssistant(postgresAssistant) + val assistantInfo = assistant.get() + println("assistant: $assistantInfo") + val thread = AssistantThread(api = postgresAssistant) + println("Enter a message or type 'exit' to quit:") + while (true) { + val input = readlnOrNull() ?: break + if (input == "exit") break + thread.createMessage(input) + thread.run(assistant).collect { + when (it) { + is RunDelta.MessageDelta -> + print(it.messageDelta.delta.content.firstOrNull()?.text?.value) + else -> it.printEvent() + } + } + } + } +} diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockExample.kt b/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockExample.kt new file mode 100644 index 000000000..73f201ff3 --- /dev/null +++ b/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockExample.kt @@ -0,0 +1,52 @@ +package com.xebia.functional.xef.aws.bedrock + +import com.xebia.functional.xef.aws.bedrock.models.BedrockModel +import com.xebia.functional.xef.conversation.Description +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.Required +import kotlinx.serialization.Serializable + +@Serializable +@Description("A planet in our solar system.") +data class Planet(@Description("The name of the planet.") val name: String) + +@Serializable +@Description("The sentiment evaluation of a text.") +data class SentimentEvaluation( + @Required + @Description( + """ + 1 if the sentiment is positive, + 0 if the sentiment is neutral, + -1 if the sentiment is negative. + """ + ) + val evaluation: Int +) + +/** + * This is an example of how to use the OpenAI API with the Bedrock runtime. Requires the following + * environment variables to be set: + * - AWS_ACCESS_KEY_ID + * - AWS_SECRET_ACCESS_KEY + * - AWS_REGION_NAME + */ +suspend fun main() { + val planet = BedrockModel.Anthropic.claude3(prompt = "The planet Mars") + println("planet: $planet") + val essayStream = + BedrockModel.Anthropic.claude3>( + prompt = "Write a critique about your less favorite planet: ${planet.name}" + ) + val essay = StringBuilder() + essayStream.collect { + print(it) + essay.append(it) + } + val sentiment = + BedrockModel.Anthropic.claude3( + prompt = "$essay\n\nWhat is the sentiment of the essay?" + ) + println() + println("sentiment: ${sentiment.evaluation}") +} diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockSDKExample.kt b/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockSDKExample.kt new file mode 100644 index 000000000..a27f964fb --- /dev/null +++ b/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockSDKExample.kt @@ -0,0 +1,53 @@ +package com.xebia.functional.xef.aws.bedrock + +import arrow.continuations.SuspendApp +import arrow.fx.coroutines.resourceScope +import aws.smithy.kotlin.runtime.client.LogMode +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel +import com.xebia.functional.xef.AI +import com.xebia.functional.xef.aws.bedrock.conf.loadEnvironment +import com.xebia.functional.xef.prompt.ToolCallStrategy +import kotlinx.coroutines.flow.Flow + +/** + * This is an example of how to use the OpenAI API with the Bedrock SDK runtime. 
Requires the + * following environment variables to be set: + * - AWS_ACCESS_KEY_ID + * - AWS_SECRET_ACCESS_KEY + * - AWS_REGION_NAME + */ +suspend fun main() = SuspendApp { + resourceScope { + val environment = loadEnvironment() + val runtimeClient = + sdkClient( + awsEnv = environment.aws, + logMode = LogMode.LogResponseWithBody + LogMode.LogRequestWithBody + ) + val bedrockClient = SdkBedrockClient(runtimeClient) + val chat = AnthropicBedrockChat(bedrockClient) + val planet = chat.claude3(prompt = "The planet Mars") + println("planet: $planet") + val essayStream = + chat.claude3>( + prompt = "Write a critique about your less favorite planet: ${planet.name}" + ) + val essay = StringBuilder() + essayStream.collect { + print(it) + essay.append(it) + } + val sentiment = + chat.claude3(prompt = "$essay\n\nWhat is the sentiment of the essay?") + println() + println("sentiment: ${sentiment.evaluation}") + } +} + +private suspend inline fun AnthropicBedrockChat.claude3(prompt: String): A = + AI( + prompt = prompt, + api = this, + model = CreateChatCompletionRequestModel.Custom("anthropic.claude-3-sonnet-20240229-v1:0"), + toolCallStrategy = ToolCallStrategy.InferXmlFromStringResponse + ) diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/models/BedrockModel.kt b/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/models/BedrockModel.kt new file mode 100644 index 000000000..11ba05b88 --- /dev/null +++ b/examples/src/main/kotlin/com/xebia/functional/xef/aws/bedrock/models/BedrockModel.kt @@ -0,0 +1,57 @@ +package com.xebia.functional.xef.aws.bedrock.models + +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.openai.generated.api.OpenAI +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel +import com.xebia.functional.xef.AI +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.OpenAI +import com.xebia.functional.xef.prompt.ToolCallStrategy + +/** + * Amazon Titan Text G1 - Express 1.x amazon.titan-text-express-v1 Amazon Titan Text G1 - Lite 1.x + * amazon.titan-text-lite-v1 Amazon Titan Text Premier 1.x amazon.titan-text-premier-v1:0 Amazon + * Titan Embeddings G1 - Text 1.x amazon.titan-embed-text-v1 Amazon Titan Embedding Text v2 1.x + * amazon.titan-embed-text-v2:0 Amazon Titan Multimodal Embeddings G1 1.x + * amazon.titan-embed-image-v1 Amazon Titan Image Generator G1 1.x amazon.titan-image-generator-v1 + * Anthropic Claude 2.0 anthropic.claude-v2 Anthropic Claude 2.1 anthropic.claude-v2:1 Anthropic + * Claude 3 Sonnet 1.0 anthropic.claude-3-sonnet-20240229-v1:0 Anthropic Claude 3 Haiku 1.0 + * anthropic.claude-3-haiku-20240307-v1:0 Anthropic Claude 3 Opus 1.0 + * anthropic.claude-3-opus-20240229-v1:0 Anthropic Claude Instant 1.x anthropic.claude-instant-v1 + * AI21 Labs Jurassic-2 Mid 1.x ai21.j2-mid-v1 AI21 Labs Jurassic-2 Ultra 1.x ai21.j2-ultra-v1 + * Cohere Command 14.x cohere.command-text-v14 Cohere Command Light 15.x + * cohere.command-light-text-v14 Cohere Command R 1.x cohere.command-r-v1:0 Cohere Command R+ 1.x + * cohere.command-r-plus-v1:0 Cohere Embed English 3.x cohere.embed-english-v3 Cohere Embed + * Multilingual 3.x cohere.embed-multilingual-v3 Meta Llama 2 Chat 13B 1.x meta.llama2-13b-chat-v1 + * Meta Llama 2 Chat 70B 1.x meta.llama2-70b-chat-v1 Meta Llama 3 8b Instruct 1.x + * meta.llama3-8b-instruct-v1:0 Meta Llama 3 70b Instruct 1.x meta.llama3-70b-instruct-v1:0 Mistral + * AI Mistral 7B Instruct 0.x mistral.mistral-7b-instruct-v0:2 Mistral AI Mixtral 8X7B 
Instruct 0.x + * mistral.mixtral-8x7b-instruct-v0:1 Mistral AI Mistral Large 1.x mistral.mistral-large-2402-v1:0 + * Stability AI Stable Diffusion XL 0.x stability.stable-diffusion-xl-v0 Stability AI Stable + * Diffusion XL 1.x stability.stable-diffusion-xl-v1 + */ +object BedrockModel { + + suspend inline fun bedrock( + model: String, + prompt: String, + config: Config = + Config( + baseUrl = "http://0.0.0.0:4000", + ), + openAI: OpenAI = OpenAI(config, logRequests = false), + api: Chat = openAI.chat + ): A = + AI( + prompt = prompt, + model = CreateChatCompletionRequestModel.Custom(model), + api = api, + toolCallStrategy = ToolCallStrategy.Supported + ) + + object Anthropic { + + suspend inline fun claude3(prompt: String): A = + bedrock("anthropic.claude-3-sonnet-20240229-v1:0", prompt) + } +} diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/Streams.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/Streams.kt index 8af782b03..dc1279cf4 100644 --- a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/Streams.kt +++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/Streams.kt @@ -1,9 +1,26 @@ package com.xebia.functional.xef.dsl.chat +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel import com.xebia.functional.xef.AI -import kotlinx.coroutines.flow.Flow +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.OpenAI +import com.xebia.functional.xef.prompt.ToolCallStrategy suspend fun main() { - val result = AI>("List of planets in the solar system") - result.collect(::print) + val result = llama3_8b("Your favorite planet") + println(result) } + +suspend inline fun llama3_8b( + prompt: String, + config: Config = Config(baseUrl = "http://localhost:11434/v1/"), + api: Chat = OpenAI(config = config, logRequests = true).chat, +): A = + AI( + prompt = prompt, + config = config, + api = api, + model = CreateChatCompletionRequestModel.Custom("llama3:8b"), + toolCallStrategy = ToolCallStrategy.InferJsonFromStringResponse + ) diff --git a/examples/src/main/resources/logback.xml b/examples/src/main/resources/logback.xml index fb058d664..0ef56869d 100644 --- a/examples/src/main/resources/logback.xml +++ b/examples/src/main/resources/logback.xml @@ -12,7 +12,7 @@ - + diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index e064e4c66..95cfe4dc1 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -15,7 +15,7 @@ kotest-arrow = "1.4.0" klogging = "6.0.9" uuid = "0.0.22" postgresql = "42.7.3" -testcontainers = "1.19.5" +testcontainers = "1.19.7" hikari = "5.1.0" dokka = "1.9.20" logback = "1.5.5" @@ -39,7 +39,7 @@ progressbar = "0.10.0" jmf = "2.1.1e" mp3-wav-converter = "1.0.4" yamlkt="0.13.0" - +aws-sdk="1.2.9" [libraries] arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" } @@ -115,7 +115,9 @@ opentelemetry-extension-kotlin = { module = "io.opentelemetry:opentelemetry-exte progressbar = { module = "me.tongfei:progressbar", version.ref = "progressbar" } jmf = { module = "javax.media:jmf", version.ref = "jmf" } mp3-wav-converter = { module = "com.sipgate:mp3-wav", version.ref = "mp3-wav-converter" } - +ollama-testcontainers = { module = "org.testcontainers:ollama", version.ref = "testcontainers" } +aws-bedrock = { module = "aws.sdk.kotlin:bedrock", version.ref = "aws-sdk" } +aws-bedrock-runtime = { module = "aws.sdk.kotlin:bedrockruntime", version.ref = "aws-sdk" } [bundles] diff --git 
a/integrations/aws/bedrock/build.gradle.kts b/integrations/aws/bedrock/build.gradle.kts new file mode 100644 index 000000000..4fb3a242f --- /dev/null +++ b/integrations/aws/bedrock/build.gradle.kts @@ -0,0 +1,69 @@ +plugins { + id(libs.plugins.kotlin.multiplatform.get().pluginId) + id(libs.plugins.kotlinx.serialization.get().pluginId) + alias(libs.plugins.spotless) + alias(libs.plugins.arrow.gradle.publish) + alias(libs.plugins.semver.gradle) + alias(libs.plugins.detekt) +} + +dependencies { detektPlugins(project(":detekt-rules")) } + +detekt { + toolVersion = "1.23.1" + source.setFrom(files("src/commonMain/kotlin", "src/jvmMain/kotlin")) + config.setFrom("../../../config/detekt/detekt.yml") + autoCorrect = true +} + +repositories { mavenCentral() } + +java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + toolchain { languageVersion = JavaLanguageVersion.of(11) } +} + +kotlin { + jvm() + sourceSets { + val commonMain by getting { + dependencies { + api(projects.xefCore) + implementation(libs.bundles.ktor.client) + implementation(libs.uuid) + implementation(libs.kotlinx.datetime) + api(libs.aws.bedrock) + api(libs.aws.bedrock.runtime) + } + } + val jvmMain by getting { + dependencies { + api(libs.ktor.client.cio) + } + } + } +} + +spotless { + kotlin { + target("**/*.kt") + ktfmt().googleStyle().configure { it.setRemoveUnusedImport(true) } + } +} + +tasks { + withType().configureEach { + dependsOn(":detekt-rules:assemble") + autoCorrect = true + } + named("detektJvmMain") { + dependsOn(":detekt-rules:assemble") + getByName("build").dependsOn(this) + } + named("detekt") { + dependsOn(":detekt-rules:assemble") + getByName("build").dependsOn(this) + } + withType { dependsOn(withType()) } +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AWSBedrockModelAdapter.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AWSBedrockModelAdapter.kt new file mode 100644 index 000000000..ffd3eef19 --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AWSBedrockModelAdapter.kt @@ -0,0 +1,27 @@ +package com.xebia.functional.xef.aws.bedrock + +import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelWithResponseStreamResponse +import aws.sdk.kotlin.services.bedrockruntime.model.Trace +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest +import com.xebia.functional.openai.generated.model.CreateChatCompletionResponse +import com.xebia.functional.openai.generated.model.CreateChatCompletionStreamResponse +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.json.JsonObject + +interface AWSBedrockModelAdapter { + val region: String + + fun chatCompletionRequest(request: CreateChatCompletionRequest): JsonObject + + fun chatCompletionStreamRequest(request: CreateChatCompletionRequest): JsonObject + + fun chatCompletionsResponse(response: JsonObject): CreateChatCompletionResponse + + val trace: Trace + val guardrailIdentifier: String? + val guardrailVersion: String? 
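  // Editor's note (illustrative commentary, not part of this change): an implementation of this
  // adapter is expected to (1) translate the OpenAI-shaped CreateChatCompletionRequest into the
  // provider's native JSON request body, (2) map the provider's JSON response back into a
  // CreateChatCompletionResponse, and (3) surface streaming output through the Flow-returning
  // method below. For a hypothetical provider that only needed max_tokens, (1) could be as small as
  //   override fun chatCompletionRequest(request: CreateChatCompletionRequest): JsonObject =
  //     buildJsonObject { put("max_tokens", request.maxTokens ?: 3000) }
  // AnthropicBedrockChat in this diff performs the equivalent Anthropic mapping directly against
  // the Chat interface, while the Anthropic adapter further below still stubs these methods with
  // TODO().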
+ + fun chatCompletionsSteamResponse( + response: InvokeModelWithResponseStreamResponse + ): Flow +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AnthropicBedrockChat.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AnthropicBedrockChat.kt new file mode 100644 index 000000000..04f461203 --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AnthropicBedrockChat.kt @@ -0,0 +1,260 @@ +package com.xebia.functional.xef.aws.bedrock + +import arrow.fx.coroutines.mapIndexed +import arrow.fx.coroutines.timeInMillis +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.openai.generated.model.* +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.aws.bedrock.SdkBedrockClient.ChatCompletionResponseEvent.* +import io.ktor.client.request.* +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.json.JsonElement +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +class AnthropicBedrockChat(private val client: BedrockClient) : Chat { + + override suspend fun createChatCompletion( + createChatCompletionRequest: CreateChatCompletionRequest, + configure: HttpRequestBuilder.() -> Unit + ): CreateChatCompletionResponse = + client + .runInference( + anthropicRequest(createChatCompletionRequest), + awsFoundationModel(createChatCompletionRequest) + ) + .let { response -> createChatCompletionResponse(response, createChatCompletionRequest) } + + private fun awsFoundationModel( + createChatCompletionRequest: CreateChatCompletionRequest + ): AwsFoundationModel = + (AwsFoundationModel.entries.find { it.awsName == createChatCompletionRequest.model.value } + ?: throw IllegalArgumentException("Model not found")) + + private fun createChatCompletionResponse( + response: ChatCompletionResponse, + createChatCompletionRequest: CreateChatCompletionRequest + ): CreateChatCompletionResponse = + CreateChatCompletionResponse( + id = response.id, + choices = + response.content.mapIndexed { index, message -> + CreateChatCompletionResponseChoicesInner( + index = index, + finishReason = CreateChatCompletionResponseChoicesInner.FinishReason.stop, + message = + ChatCompletionResponseMessage( + content = message.text, + role = ChatCompletionResponseMessage.Role.assistant + ), + logprobs = null + ) + }, + created = (timeInMillis() / 1000).toInt(), + model = createChatCompletionRequest.model.value, + `object` = CreateChatCompletionResponse.Object.chat_completion, + usage = + response.usage?.run { + CompletionUsage( + completionTokens = outputTokens ?: 0, + promptTokens = inputTokens ?: 0, + totalTokens = + inputTokens?.let { input -> outputTokens?.let { output -> input + output } } ?: 0 + ) + } + ) + + private fun anthropicRequest( + createChatCompletionRequest: CreateChatCompletionRequest + ): JsonElement { + val systemMessage = + createChatCompletionRequest.messages + .filterIsInstance() + .joinToString("\n") { it.value.content } + val remainingMessages = + createChatCompletionRequest.messages.filterNot { + it is ChatCompletionRequestMessage.CaseChatCompletionRequestSystemMessage + } + // the remaining messages should not have messages of the same role one after the other + // so we need to alternate between assistant and user messages + // and ensure that the first message is a user message + // additionally if messages of the same role are consecutive they should be combined into a + // single message + + val cleanedMessages 
= enforceAnthropicFormat(remainingMessages) + + return AnthropicMessagesRequestBody( + system = systemMessage, + topP = createChatCompletionRequest.topP, + temperature = createChatCompletionRequest.temperature, + maxTokens = createChatCompletionRequest.maxTokens ?: 3000, + stopSequences = + when (val stop = createChatCompletionRequest.stop) { + is CreateChatCompletionRequestStop.CaseString -> listOf(stop.value) + is CreateChatCompletionRequestStop.CaseStrings -> stop.value + null -> null + }, + messages = cleanedMessages, + // tools = createChatCompletionRequest.tools?.map { + // AnthropicChatCompletionTool( + // name = it.function.name, + // description = it.function.description ?: "", + // inputSchema = it.function.parameters ?: JsonObject(emptyMap()) + // ) + // }, //TODO enable when bedrock support anthropic tool calling + ) + .let { + Config.DEFAULT.json.encodeToJsonElement(AnthropicMessagesRequestBody.serializer(), it) + } + } + + private fun enforceAnthropicFormat( + remainingMessages: List + ): List = + remainingMessages.fold(mutableListOf()) { acc, message -> + if (acc.isEmpty()) { + acc.add(message) + } else { + val lastMessage = acc.last() + val lastContent = extractContent(lastMessage) + val messageContent = extractContent(message) + if ( + lastMessage is ChatCompletionRequestMessage.CaseChatCompletionRequestUserMessage && + message is ChatCompletionRequestMessage.CaseChatCompletionRequestUserMessage + ) { + acc[acc.lastIndex] = + ChatCompletionRequestMessage.CaseChatCompletionRequestUserMessage( + value = + ChatCompletionRequestUserMessage( + role = ChatCompletionRequestUserMessage.Role.user, + content = + ChatCompletionRequestUserMessageContent.CaseString( + value = "${lastContent}\n${messageContent}" + ) + ) + ) + } else if ( + lastMessage is ChatCompletionRequestMessage.CaseChatCompletionRequestSystemMessage && + message is ChatCompletionRequestMessage.CaseChatCompletionRequestSystemMessage + ) { + acc[acc.lastIndex] = + ChatCompletionRequestMessage.CaseChatCompletionRequestSystemMessage( + value = + ChatCompletionRequestSystemMessage( + role = ChatCompletionRequestSystemMessage.Role.system, + content = "${lastContent}\n${messageContent}" + ) + ) + } else if ( + lastMessage is ChatCompletionRequestMessage.CaseChatCompletionRequestAssistantMessage && + message is ChatCompletionRequestMessage.CaseChatCompletionRequestAssistantMessage + ) { + acc[acc.lastIndex] = + ChatCompletionRequestMessage.CaseChatCompletionRequestAssistantMessage( + value = + ChatCompletionRequestAssistantMessage( + role = ChatCompletionRequestAssistantMessage.Role.assistant, + content = "${lastContent}\n${messageContent}" + ) + ) + } else { + acc.add(message) + } + } + acc + } + + private fun extractContent(lastMessage: ChatCompletionRequestMessage): String? 
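  // Editor's note: extractContent (below) flattens any ChatCompletionRequestMessage variant to its
  // plain text so the merging above can concatenate consecutive same-role messages; content parts
  // are joined line by line, with image parts reduced to their URL.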
{ + return when (lastMessage) { + is ChatCompletionRequestMessage.CaseChatCompletionRequestUserMessage -> + when (val content = lastMessage.value.content) { + is ChatCompletionRequestUserMessageContent.CaseString -> content.value + is ChatCompletionRequestUserMessageContent.CaseChatCompletionRequestMessageContentParts -> + content.value.joinToString("\n") { + when (it) { + is ChatCompletionRequestMessageContentPart.CaseChatCompletionRequestMessageContentPartImage -> + it.value.imageUrl.url + is ChatCompletionRequestMessageContentPart.CaseChatCompletionRequestMessageContentPartText -> + it.value.text + } + } + } + is ChatCompletionRequestMessage.CaseChatCompletionRequestSystemMessage -> + lastMessage.value.content + is ChatCompletionRequestMessage.CaseChatCompletionRequestAssistantMessage -> + lastMessage.value.content + is ChatCompletionRequestMessage.CaseChatCompletionRequestFunctionMessage -> + lastMessage.value.content + is ChatCompletionRequestMessage.CaseChatCompletionRequestToolMessage -> + lastMessage.value.content + } + } + + override fun createChatCompletionStream( + createChatCompletionRequest: CreateChatCompletionRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow = + client + .runInferenceWithStream( + anthropicRequest(createChatCompletionRequest), + awsFoundationModel(createChatCompletionRequest) + ) + .mapIndexed { index, event -> + CreateChatCompletionStreamResponse( + id = UUID.generateUUID().toString(), + choices = + when (event) { + is ContentBlockDelta -> + listOf( + CreateChatCompletionStreamResponseChoicesInner( + index = event.index, + finishReason = null, + delta = + ChatCompletionStreamResponseDelta( + content = event.delta.text, + role = ChatCompletionStreamResponseDelta.Role.assistant, + ), + logprobs = null + ) + ) + is ContentBlockStart -> emptyList() + is ContentBlockStop -> emptyList() + is MessageDelta -> + listOf( + CreateChatCompletionStreamResponseChoicesInner( + index = + 0, // TODO: index is always 0 for now, need to update when we have multiple + // messages + finishReason = null, + delta = + ChatCompletionStreamResponseDelta( + content = event.delta.text, + role = ChatCompletionStreamResponseDelta.Role.assistant, + ), + logprobs = null + ) + ) + is MessageStart -> emptyList() + is MessageStop -> + listOf( + CreateChatCompletionStreamResponseChoicesInner( + index = + 0, // TODO: index is always 0 for now, need to update when we have multiple + // messages + finishReason = CreateChatCompletionStreamResponseChoicesInner.FinishReason.stop, + delta = + ChatCompletionStreamResponseDelta( + content = "", + role = ChatCompletionStreamResponseDelta.Role.assistant, + ), + logprobs = null + ) + ) + is Ping -> emptyList() + }, + created = (timeInMillis() / 1000).toInt(), + model = createChatCompletionRequest.model.value, + `object` = CreateChatCompletionStreamResponse.Object.chat_completion_chunk, + ) + } +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AwsFoundationModel.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AwsFoundationModel.kt new file mode 100644 index 000000000..a8817020d --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/AwsFoundationModel.kt @@ -0,0 +1,16 @@ +package com.xebia.functional.xef.aws.bedrock + +enum class AwsFoundationModel(val awsName: String) { + AmazonTitanTextG1Large("amazon.titan-tg1-large"), + AmazonTitanTextG1Express("amazon.titan-text-express-v1"), + 
AI21LabsJurassic2Mid("ai21.j2-mid-v1"), + AI21LabsJurassic2Ultra("ai21.j2-ultra-v1"), + AnthropicClaudeInstantV1("anthropic.claude-instant-v1"), + AnthropicClaudeV1("anthropic.claude-v1"), + AnthropicClaudeV2("anthropic.claude-v2"), + AnthropicClaudeV2_1("anthropic.claude-v2:1"), + AnthropicClaude3Sonnet20240229V10("anthropic.claude-3-sonnet-20240229-v1:0"), + CohereCommand("cohere.command-text-v14"), + StabilityAIStableDiffusionXLV0("stability.stable-diffusion-xl-v0"), + Empty("empty") +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockAssistants.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockAssistants.kt new file mode 100644 index 000000000..d3fcf8816 --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockAssistants.kt @@ -0,0 +1,274 @@ +package com.xebia.functional.xef.aws.bedrock + +import com.xebia.functional.openai.ServerSentEvent +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.* +import io.ktor.client.request.* +import kotlinx.coroutines.flow.Flow + +class BedrockAssistants : Assistants { + override suspend fun cancelRun( + threadId: String, + runId: String, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override suspend fun createAssistant( + createAssistantRequest: CreateAssistantRequest, + configure: HttpRequestBuilder.() -> Unit + ): AssistantObject { + TODO("Not yet implemented") + } + + override suspend fun createAssistantFile( + assistantId: String, + createAssistantFileRequest: CreateAssistantFileRequest, + configure: HttpRequestBuilder.() -> Unit + ): AssistantFileObject { + TODO("Not yet implemented") + } + + override suspend fun createMessage( + threadId: String, + createMessageRequest: CreateMessageRequest, + configure: HttpRequestBuilder.() -> Unit + ): MessageObject { + TODO("Not yet implemented") + } + + override suspend fun createRun( + threadId: String, + createRunRequest: CreateRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override fun createRunStream( + threadId: String, + createRunRequest: CreateRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow { + TODO("Not yet implemented") + } + + override suspend fun createThread( + createThreadRequest: CreateThreadRequest?, + configure: HttpRequestBuilder.() -> Unit + ): ThreadObject { + TODO("Not yet implemented") + } + + override suspend fun createThreadAndRun( + createThreadAndRunRequest: CreateThreadAndRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override fun createThreadAndRunStream( + createThreadAndRunRequest: CreateThreadAndRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow { + TODO("Not yet implemented") + } + + override suspend fun deleteAssistant( + assistantId: String, + configure: HttpRequestBuilder.() -> Unit + ): DeleteAssistantResponse { + TODO("Not yet implemented") + } + + override suspend fun deleteAssistantFile( + assistantId: String, + fileId: String, + configure: HttpRequestBuilder.() -> Unit + ): DeleteAssistantFileResponse { + TODO("Not yet implemented") + } + + override suspend fun deleteThread( + threadId: String, + configure: HttpRequestBuilder.() -> Unit + ): DeleteThreadResponse { + TODO("Not yet implemented") + } + + override suspend fun getAssistant( + assistantId: 
String, + configure: HttpRequestBuilder.() -> Unit + ): AssistantObject { + TODO("Not yet implemented") + } + + override suspend fun getAssistantFile( + assistantId: String, + fileId: String, + configure: HttpRequestBuilder.() -> Unit + ): AssistantFileObject { + TODO("Not yet implemented") + } + + override suspend fun getMessage( + threadId: String, + messageId: String, + configure: HttpRequestBuilder.() -> Unit + ): MessageObject { + TODO("Not yet implemented") + } + + override suspend fun getMessageFile( + threadId: String, + messageId: String, + fileId: String, + configure: HttpRequestBuilder.() -> Unit + ): MessageFileObject { + TODO("Not yet implemented") + } + + override suspend fun getRun( + threadId: String, + runId: String, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override suspend fun getRunStep( + threadId: String, + runId: String, + stepId: String, + configure: HttpRequestBuilder.() -> Unit + ): RunStepObject { + TODO("Not yet implemented") + } + + override suspend fun getThread( + threadId: String, + configure: HttpRequestBuilder.() -> Unit + ): ThreadObject { + TODO("Not yet implemented") + } + + override suspend fun listAssistantFiles( + assistantId: String, + limit: Int?, + order: Assistants.OrderListAssistantFiles?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListAssistantFilesResponse { + TODO("Not yet implemented") + } + + override suspend fun listAssistants( + limit: Int?, + order: Assistants.OrderListAssistants?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListAssistantsResponse { + TODO("Not yet implemented") + } + + override suspend fun listMessageFiles( + threadId: String, + messageId: String, + limit: Int?, + order: Assistants.OrderListMessageFiles?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListMessageFilesResponse { + TODO("Not yet implemented") + } + + override suspend fun listMessages( + threadId: String, + limit: Int?, + order: Assistants.OrderListMessages?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListMessagesResponse { + TODO("Not yet implemented") + } + + override suspend fun listRunSteps( + threadId: String, + runId: String, + limit: Int?, + order: Assistants.OrderListRunSteps?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListRunStepsResponse { + TODO("Not yet implemented") + } + + override suspend fun listRuns( + threadId: String, + limit: Int?, + order: Assistants.OrderListRuns?, + after: String?, + before: String?, + configure: HttpRequestBuilder.() -> Unit + ): ListRunsResponse { + TODO("Not yet implemented") + } + + override suspend fun modifyAssistant( + assistantId: String, + modifyAssistantRequest: ModifyAssistantRequest, + configure: HttpRequestBuilder.() -> Unit + ): AssistantObject { + TODO("Not yet implemented") + } + + override suspend fun modifyMessage( + threadId: String, + messageId: String, + modifyMessageRequest: ModifyMessageRequest, + configure: HttpRequestBuilder.() -> Unit + ): MessageObject { + TODO("Not yet implemented") + } + + override suspend fun modifyRun( + threadId: String, + runId: String, + modifyRunRequest: ModifyRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override suspend fun modifyThread( + threadId: String, + modifyThreadRequest: ModifyThreadRequest, + configure: HttpRequestBuilder.() -> Unit + 
): ThreadObject { + TODO("Not yet implemented") + } + + override suspend fun submitToolOuputsToRun( + threadId: String, + runId: String, + submitToolOutputsRunRequest: SubmitToolOutputsRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): RunObject { + TODO("Not yet implemented") + } + + override fun submitToolOuputsToRunStream( + threadId: String, + runId: String, + submitToolOutputsRunRequest: SubmitToolOutputsRunRequest, + configure: HttpRequestBuilder.() -> Unit + ): Flow { + TODO("Not yet implemented") + } +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockClient.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockClient.kt new file mode 100644 index 000000000..c24911d7c --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockClient.kt @@ -0,0 +1,16 @@ +package com.xebia.functional.xef.aws.bedrock + +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.json.JsonElement + +interface BedrockClient { + suspend fun runInference( + requestBody: JsonElement, + model: AwsFoundationModel + ): ChatCompletionResponse + + fun runInferenceWithStream( + requestBody: JsonElement, + model: AwsFoundationModel + ): Flow +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockRequestBody.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockRequestBody.kt new file mode 100644 index 000000000..593cdbd26 --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockRequestBody.kt @@ -0,0 +1,54 @@ +package com.xebia.functional.xef.aws.bedrock + +import com.xebia.functional.openai.generated.model.ChatCompletionRequestMessage +import kotlinx.serialization.Required +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement + +@Serializable +data class AnthropicMessagesRequestBody( + @Required val messages: List, + @Required @SerialName("anthropic_version") val anthropicVersion: String = "bedrock-2023-05-31", + @Required @SerialName("max_tokens") val maxTokens: Int = 3000, + val system: String? = null, + @Required val temperature: Double? = null, + @Required @SerialName("top_p") val topP: Double? = null, + @Required @SerialName("top_k") val topK: Int? = null, + @Required @SerialName("stop_sequences") val stopSequences: List? = null, + + /** + * This is an issue also reported in https://github.com/langchain-ai/langchain/issues/20320 As of + * now, the `tools` field is not documented in the Bedrock API documentation. Users are advised to + * use https://docs.anthropic.com/claude/docs/legacy-tool-use which we will have to implement if + * we use bedrock instead of anthropic's api directly. + * + * Note: The new tool use format is not yet available on Vertex AI or Amazon Bedrock, but is + * coming soon to those platforms. + */ + val tools: List? = null +) + +/* +{ + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA" + } + }, + "required": ["location"] + } + } + */ + +@Serializable +data class AnthropicChatCompletionTool( + @Required val name: String, + @Required val description: String, + @Required @SerialName("input_schema") val inputSchema: JsonElement +) diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockRuntimeModels.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockRuntimeModels.kt new file mode 100644 index 000000000..139134ea3 --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/BedrockRuntimeModels.kt @@ -0,0 +1,66 @@ +package com.xebia.functional.xef.aws.bedrock + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * The Anthropic Claude model returns the following fields for a messages inference call. + * + * { "id": string, "model": string, "type" : "message", "role" : "assistant", "content": + * [ { "type": "text", "text": string } ], "stop_reason": string, "stop_sequence": string, "usage": + * { "input_tokens": integer, "output_tokens": integer } + * + * } id – The unique identifier for the response. The format and length of the ID might change over + * time. + * + * model – The ID for the Anthropic Claude model that made the request. + * + * stop_reason – The reason why Anthropic Claude stopped generating the response. + * + * end_turn – The model reached a natural stopping point + * + * max_tokens – The generated text exceeded the value of the max_tokens input field or exceeded the + * maximum number of tokens that the model supports.' . + * + * stop_sequence – The model generated one of the stop sequences that you specified in the + * stop_sequences input field. + * + * type – The type of response. The value is always message. + * + * role – The conversational role of the generated message. The value is always assistant. + * + * content – The content generated by the model. Returned as an array. + * + * type – The type of the content. Currently the only supported value is text. + * + * text – The text of the content. + * + * usage – Container for the number of tokens that you supplied in the request and the number tokens + * of that the model generated in the response. + * + * input_tokens – The number of input tokens in the request. + * + * output_tokens – The number tokens of that the model generated in the response. + * + * stop_sequence – The model generated one of the stop sequences that you specified in the + * stop_sequences input field. + */ +@Serializable +data class ChatCompletionResponse( + val id: String, + val model: String, + val type: String, + val role: String, + val content: List, + @SerialName("stop_reason") val stopReason: String? = null, + @SerialName("stop_sequence") val stopSequence: String? = null, + val usage: Usage? = null +) { + @Serializable data class Content(val type: String, val text: String) + + @Serializable + data class Usage( + @SerialName("input_tokens") val inputTokens: Int? = null, + @SerialName("output_tokens") val outputTokens: Int? 
= null + ) +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/SdkBedrockClient.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/SdkBedrockClient.kt new file mode 100644 index 000000000..405c8d083 --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/SdkBedrockClient.kt @@ -0,0 +1,205 @@ +package com.xebia.functional.xef.aws.bedrock + +import arrow.fx.coroutines.ResourceScope +import arrow.fx.coroutines.closeable +import aws.sdk.kotlin.runtime.auth.credentials.StaticCredentialsProvider +import aws.sdk.kotlin.services.bedrockruntime.BedrockRuntimeClient +import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelRequest +import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest +import aws.smithy.kotlin.runtime.client.LogMode +import com.xebia.functional.xef.Config +import com.xebia.functional.xef.aws.bedrock.conf.Environment +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +suspend fun ResourceScope.sdkClient( + awsEnv: Environment.Aws, + logMode: LogMode = LogMode.Default +): BedrockRuntimeClient = closeable { + BedrockRuntimeClient { + this.logMode = logMode + region = awsEnv.regionName + credentialsProvider = + StaticCredentialsProvider( + aws.smithy.kotlin.runtime.auth.awscredentials.Credentials.invoke( + awsEnv.credentials.accessKeyId, + awsEnv.credentials.secretAccessKey.value + ) + ) + } +} + +class SdkBedrockClient(private val client: BedrockRuntimeClient) : BedrockClient { + override suspend fun runInference( + requestBody: JsonElement, + model: AwsFoundationModel + ): ChatCompletionResponse { + val invokeModelRequest = + InvokeModelRequest.invoke { + modelId = model.awsName + accept = "application/json" + contentType = "application/json" + body = Json.encodeToString(requestBody).toByteArray() + } + + val responseBody = String(client.invokeModel(invokeModelRequest).body) + + return Json.decodeFromString(responseBody) + } + + /** + * event: message_start data: {"type": "message_start", "message": {"id": + * "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": + * [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": + * {"input_tokens": 25, "output_tokens": 1}}} + * + * event: content_block_start data: {"type": "content_block_start", "index": 0, "content_block": + * {"type": "text", "text": ""}} + * + * event: ping data: {"type": "ping"} + * + * event: content_block_delta data: {"type": "content_block_delta", "index": 0, "delta": {"type": + * "text_delta", "text": "Hello"}} + * + * event: content_block_delta data: {"type": "content_block_delta", "index": 0, "delta": {"type": + * "text_delta", "text": "!"}} + * + * event: content_block_stop data: {"type": "content_block_stop", "index": 0} + * + * event: message_delta data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", + * "stop_sequence":null}, "usage": {"output_tokens": 15}} + * + * event: message_stop data: {"type": "message_stop"} + * + * This are what the 
ChatCompletionStreamEvent types and their data look like We model the data + * part which already includes the type + */ + @Serializable + enum class EventType { + message_start, + content_block_start, + ping, + content_block_delta, + content_block_stop, + message_delta, + message_stop + } + + @Serializable + data class ChatCompletionStreamEvent( + val type: EventType, + val message: ChatCompletionResponseEvent + ) + + @Serializable + sealed class ChatCompletionResponseEvent { + @Serializable + data class MessageStart( + val id: String, + val type: String, + val role: String, + val content: List, + val model: String, + @SerialName("stop_reason") val stopReason: String?, + @SerialName("stop_sequence") val stopSequence: String?, + val usage: ChatCompletionResponse.Usage + ) : ChatCompletionResponseEvent() + + @Serializable + data class ContentBlockStart( + @SerialName("content_block") val contentBlock: ChatCompletionResponse.Content + ) : ChatCompletionResponseEvent() + + @Serializable data class Ping(val type: String) : ChatCompletionResponseEvent() + + @Serializable + data class ContentBlockDelta(val index: Int, val delta: Delta) : ChatCompletionResponseEvent() + + @Serializable data class Delta(val type: String? = null, val text: String? = null) + + @Serializable data class ContentBlockStop(val index: Int) : ChatCompletionResponseEvent() + + @Serializable + data class MessageDelta(val delta: Delta, val usage: ChatCompletionResponse.Usage? = null) : + ChatCompletionResponseEvent() + + @Serializable data class MessageStop(val type: String) : ChatCompletionResponseEvent() + } + + @OptIn(ExperimentalSerializationApi::class) + override fun runInferenceWithStream( + requestBody: JsonElement, + model: AwsFoundationModel + ): Flow { + val json = Json { + explicitNulls = false + ignoreUnknownKeys = true + } + val streamRequest = + InvokeModelWithResponseStreamRequest.invoke { + modelId = model.awsName + accept = "application/json" + contentType = "application/json" + body = Json.encodeToString(requestBody).toByteArray() + } + + return channelFlow { + client.invokeModelWithResponseStream(streamRequest) { response -> + response.body?.collect { responseStream -> + val chunk = responseStream.asChunkOrNull()?.bytes?.toString(Charsets.UTF_8) + chunk?.let { + val jsonEvent = json.parseToJsonElement(it).jsonObject + val event = jsonEvent["type"]?.jsonPrimitive?.content + if (event != null) { + val responseEvent = serverSentEventToChatCompletionStreamEvent(event, chunk) + send(responseEvent) + } + } + } + } + } + } + + fun serverSentEventToChatCompletionStreamEvent( + event: String, + data: String + ): ChatCompletionResponseEvent { + val eventType = EventType.valueOf(event) + val json = Config.DEFAULT.json + val jsonData = json.parseToJsonElement(data).jsonObject + return when (eventType) { + EventType.message_start -> + json.decodeFromJsonElement( + ChatCompletionResponseEvent.MessageStart.serializer(), + jsonData["message"]!! 
+ ) + EventType.content_block_start -> + json.decodeFromJsonElement( + ChatCompletionResponseEvent.ContentBlockStart.serializer(), + jsonData + ) + EventType.ping -> ChatCompletionResponseEvent.Ping(jsonData["type"]!!.jsonPrimitive.content) + EventType.content_block_delta -> + json.decodeFromJsonElement( + ChatCompletionResponseEvent.ContentBlockDelta.serializer(), + jsonData + ) + EventType.content_block_stop -> + ChatCompletionResponseEvent.ContentBlockStop( + jsonData["index"]!!.jsonPrimitive.content.toInt() + ) + EventType.message_delta -> + json.decodeFromJsonElement(ChatCompletionResponseEvent.MessageDelta.serializer(), jsonData) + EventType.message_stop -> + ChatCompletionResponseEvent.MessageStop(jsonData["type"]!!.jsonPrimitive.content) + } + } +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/anthropic/Anthropic.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/anthropic/Anthropic.kt new file mode 100644 index 000000000..d5c679bab --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/anthropic/Anthropic.kt @@ -0,0 +1,28 @@ +package com.xebia.functional.xef.aws.bedrock.anthropic + +import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelWithResponseStreamResponse +import aws.sdk.kotlin.services.bedrockruntime.model.Trace +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest +import com.xebia.functional.openai.generated.model.CreateChatCompletionResponse +import com.xebia.functional.openai.generated.model.CreateChatCompletionStreamResponse +import com.xebia.functional.xef.aws.bedrock.AWSBedrockModelAdapter +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.json.JsonObject + +data class Anthropic( + override val region: String, + override val trace: Trace = Trace.Disabled, + override val guardrailIdentifier: String? = null, + override val guardrailVersion: String? 
= null +) : AWSBedrockModelAdapter { + override fun chatCompletionRequest(request: CreateChatCompletionRequest): JsonObject = TODO() + + override fun chatCompletionStreamRequest(request: CreateChatCompletionRequest): JsonObject = + TODO() + + override fun chatCompletionsResponse(response: JsonObject): CreateChatCompletionResponse = TODO() + + override fun chatCompletionsSteamResponse( + response: InvokeModelWithResponseStreamResponse + ): Flow = TODO() +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/conf/Environment.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/conf/Environment.kt new file mode 100644 index 000000000..af663c45a --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/conf/Environment.kt @@ -0,0 +1,50 @@ +package com.xebia.functional.xef.aws.bedrock.conf + +import arrow.core.NonEmptyList +import arrow.core.getOrElse +import arrow.core.raise.either +import arrow.core.raise.ensure +import arrow.core.raise.zipOrAccumulate +import conf.ValidationError +import java.lang.System.getenv + +@JvmInline +value class Secret(val value: String) { + override fun toString(): String { + return value.replaceRange(0, value.length - 3, "*") + } +} + +fun String.secret(): Secret = Secret(this) + +data class Environment(val aws: Aws = Aws()) { + data class Aws( + val credentials: Credentials = Credentials(), + val regionName: String = getenv("AWS_REGION_NAME") ?: "us-east-1" + ) { + data class Credentials( + val accessKeyId: String = getenv("AWS_ACCESS_KEY_ID"), + val secretAccessKey: Secret = getenv("AWS_SECRET_ACCESS_KEY").secret() + ) + } +} + +fun loadEnvironment(): Environment { + val environment = Environment() + return either, Environment> { + zipOrAccumulate( + { + ensure(environment.aws.credentials.accessKeyId.isNotBlank()) { + raise(ValidationError.AwsAccessKeyIdNotProvided) + } + }, + { + ensure(environment.aws.credentials.secretAccessKey.value.isNotBlank()) { + raise(ValidationError.AwsSecretAccessKeyNotProvided) + } + }, + { _, _ -> environment } + ) + } + .getOrElse { throw RuntimeException(it.joinToString(transform = ValidationError::toString)) } +} diff --git a/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/conf/ValidationError.kt b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/conf/ValidationError.kt new file mode 100644 index 000000000..a55ba3aca --- /dev/null +++ b/integrations/aws/bedrock/src/commonMain/kotlin/com/xebia/functional/xef/aws/bedrock/conf/ValidationError.kt @@ -0,0 +1,7 @@ +package conf + +sealed interface ValidationError { + data object AwsAccessKeyIdNotProvided : ValidationError + + data object AwsSecretAccessKeyNotProvided : ValidationError +} diff --git a/openai-client/generator/config/api.mustache b/openai-client/generator/config/api.mustache index 026d62207..f42895507 100644 --- a/openai-client/generator/config/api.mustache +++ b/openai-client/generator/config/api.mustache @@ -156,7 +156,7 @@ fun {{classname}}(client: HttpClient, config: Config): {{classname}} = object : {{#queryParams}} parameter("{{baseName}}", {{#isContainer}}toMultiValue(this, "{{collectionFormat}}"){{/isContainer}}{{^isContainer}}listOf("${{{paramName}}}"){{/isContainer}}) {{/queryParams}} - url { path("{{path}}"{{#pathParams}}.replace("{" + "{{baseName}}" + "}", 
{{#isContainer}}{{paramName}}.joinToString(","){{/isContainer}}{{^isContainer}}"${{{paramName}}}"{{/isContainer}}){{/pathParams}}) } + url { path("{{#lambda.dropslash}}{{path}}{{/lambda.dropslash}}"{{#pathParams}}.replace("{" + "{{baseName}}" + "}", {{#isContainer}}{{paramName}}.joinToString(","){{/isContainer}}{{^isContainer}}"${{{paramName}}}"{{/isContainer}}){{/pathParams}}) } {{#hasBodyParam}} val element = Json.encodeToJsonElement({{#lambda.serializer}}{{#bodyParam}}{{{baseName}}}{{/bodyParam}}{{/lambda.serializer}}, {{#bodyParam}}{{{paramName}}}{{/bodyParam}}) val jsObject = JsonObject(element.jsonObject + Pair("stream", JsonPrimitive(true))) @@ -190,4 +190,4 @@ fun {{classname}}(client: HttpClient, config: Config): {{classname}} = object : {{/vendorExtensions.x-streaming}} {{/operation}} } -{{/operations}} \ No newline at end of file +{{/operations}} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/PostgresAssistant.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/PostgresAssistant.kt new file mode 100644 index 000000000..35d680338 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/PostgresAssistant.kt @@ -0,0 +1,20 @@ +package com.xebia.functional.xef.server.assistants.postgres + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.api.Chat +import com.xebia.functional.xef.llm.assistants.local.GeneralAssistants +import com.xebia.functional.xef.server.assistants.postgres.tables.* + +object PostgresAssistant { + operator fun invoke(api: Chat): Assistants = + GeneralAssistants( + api = api, + assistantPersistence = AssistantsTable, + assistantFilesPersistence = AssistantsFilesTable, + threadPersistence = ThreadsTable, + messagePersistence = MessagesTable, + messageFilesPersistence = MessagesFilesTable, + runPersistence = RunsTable, + runStepPersistence = RunsStepsTable + ) +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/AssistantsFilesTable.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/AssistantsFilesTable.kt new file mode 100644 index 000000000..f2c3ad22b --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/AssistantsFilesTable.kt @@ -0,0 +1,89 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.* +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.assistantFileObject +import com.xebia.functional.xef.server.db.tables.format +import java.util.* +import org.jetbrains.exposed.sql.* +import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq +import org.jetbrains.exposed.sql.json.extract +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction + +object AssistantsFilesTable : Table("xef_assistants_files"), AssistantPersistence.AssistantFiles { + + val id = uuid("id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun create( + assistantId: String, + createAssistantFileRequest: CreateAssistantFileRequest + ): AssistantFileObject = transaction { + val uuid = UUID.randomUUID() + val assistantFileObject = assistantFileObject(createAssistantFileRequest, assistantId) + 
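    // Persist the freshly built AssistantFileObject as a jsonb document keyed by a new UUID; the
    // assistantId travels inside the json payload and is what get/list/delete filter on below.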
AssistantsFilesTable.insert { + it[id] = uuid + it[data] = assistantFileObject + } + assistantFileObject + } + + override suspend fun delete(assistantId: String, fileId: String): Boolean = transaction { + AssistantsFilesTable.deleteWhere { + data.extract("id") eq fileId and (data.extract("assistantId") eq assistantId) + } > 0 + } + + override suspend fun get(assistantId: String, fileId: String): AssistantFileObject = transaction { + AssistantsFilesTable.select(data) + .where { + (AssistantsFilesTable.id eq UUID.fromString(fileId)) and + (data.extract("assistantId") eq assistantId) + } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Assistant file not found for id: $fileId") + } + + override suspend fun list( + assistantId: String, + limit: Int?, + order: Assistants.OrderListAssistantFiles?, + after: String?, + before: String? + ): ListAssistantFilesResponse { + val query = + AssistantsFilesTable.select(data).where { data.extract("assistantId") eq assistantId } + val sortedQuery = + when (order) { + Assistants.OrderListAssistantFiles.asc -> + query.orderBy(data.extract("createdAt") to SortOrder.ASC) + Assistants.OrderListAssistantFiles.desc -> + query.orderBy(data.extract("createdAt") to SortOrder.DESC) + null -> query + } + val afterFile = after?.let { UUID.fromString(it) } + val beforeFile = before?.let { UUID.fromString(it) } + val afterFileIndex = + afterFile?.let { sortedQuery.indexOfFirst { it[data].id == afterFile.toString() } } + val beforeFileIndex = + beforeFile?.let { sortedQuery.indexOfFirst { it[data].id == beforeFile.toString() } } + val slicedQuery = + when { + afterFileIndex != null -> sortedQuery.drop(afterFileIndex + 1) + beforeFileIndex != null -> sortedQuery.take(beforeFileIndex) + else -> sortedQuery + } + val limitedQuery = limit?.let { slicedQuery.take(it) } ?: slicedQuery + return ListAssistantFilesResponse( + `object` = "list", + data = limitedQuery.map { it[data] }, + firstId = limitedQuery.firstOrNull()?.get(data)?.id, + lastId = limitedQuery.lastOrNull()?.get(data)?.id, + hasMore = sortedQuery.count() > limitedQuery.count() + ) + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/AssistantsTable.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/AssistantsTable.kt new file mode 100644 index 000000000..d5d2e3a24 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/AssistantsTable.kt @@ -0,0 +1,104 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.AssistantObject +import com.xebia.functional.openai.generated.model.CreateAssistantRequest +import com.xebia.functional.openai.generated.model.ListAssistantsResponse +import com.xebia.functional.openai.generated.model.ModifyAssistantRequest +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.assistantObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.modifiedAssistantObject +import com.xebia.functional.xef.server.db.tables.format +import java.util.* +import org.jetbrains.exposed.sql.* +import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq +import org.jetbrains.exposed.sql.json.extract +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction + +object AssistantsTable : 
Table("xef_assistants"), AssistantPersistence.Assistant { + val id = uuid("id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun create(createAssistantRequest: CreateAssistantRequest): AssistantObject = + transaction { + val uuid = UUID.randomUUID() + val assistantObject = + assistantObject(kotlinx.uuid.UUID(uuid.toString()), createAssistantRequest) + AssistantsTable.insert { + it[id] = uuid + it[data] = assistantObject + } + assistantObject + } + + override suspend fun get(assistantId: String): AssistantObject = transaction { + AssistantsTable.select(data) + .where { AssistantsTable.id eq UUID.fromString(assistantId) } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Assistant not found for id: $assistantId") + } + + override suspend fun delete(assistantId: String): Boolean = transaction { + AssistantsTable.deleteWhere { id eq UUID.fromString(assistantId) } > 0 + } + + override suspend fun list( + limit: Int?, + order: Assistants.OrderListAssistants?, + after: String?, + before: String? + ): ListAssistantsResponse { + return transaction { + val query = AssistantsTable.selectAll() + val sortedQuery = + when (order) { + Assistants.OrderListAssistants.asc -> + query.orderBy(data.extract("createdAt") to SortOrder.ASC) + Assistants.OrderListAssistants.desc -> + query.orderBy(data.extract("createdAt") to SortOrder.DESC) + null -> query + } + val afterAssistant = after?.let { UUID.fromString(it) } + val beforeAssistant = before?.let { UUID.fromString(it) } + val afterAssistantIndex = + afterAssistant?.let { + sortedQuery.indexOfFirst { it[data].id == afterAssistant.toString() } + } + val beforeAssistantIndex = + beforeAssistant?.let { + sortedQuery.indexOfFirst { it[data].id == beforeAssistant.toString() } + } + val slicedQuery = + when { + afterAssistantIndex != null -> sortedQuery.drop(afterAssistantIndex + 1) + beforeAssistantIndex != null -> sortedQuery.take(beforeAssistantIndex) + else -> sortedQuery + } + val limitedQuery = limit?.let { slicedQuery.take(it) } ?: slicedQuery + ListAssistantsResponse( + `object` = "list", + data = limitedQuery.map { it[data] }, + firstId = limitedQuery.firstOrNull()?.get(data)?.id, + lastId = limitedQuery.lastOrNull()?.get(data)?.id, + hasMore = sortedQuery.count() > limitedQuery.count() + ) + } + } + + override suspend fun modify( + assistantId: String, + modifyAssistantRequest: ModifyAssistantRequest + ): AssistantObject { + val assistantObject = get(assistantId) + return transaction { + val modifiedAssistantObject = modifiedAssistantObject(assistantObject, modifyAssistantRequest) + AssistantsTable.update({ AssistantsTable.id eq UUID.fromString(assistantId) }) { + it[data] = modifiedAssistantObject + } + modifiedAssistantObject + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/MessagesFilesTable.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/MessagesFilesTable.kt new file mode 100644 index 000000000..d46241877 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/MessagesFilesTable.kt @@ -0,0 +1,81 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.ListMessageFilesResponse +import com.xebia.functional.openai.generated.model.MessageFileObject +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence 
+import com.xebia.functional.xef.server.db.tables.format +import java.util.* +import org.jetbrains.exposed.sql.SortOrder +import org.jetbrains.exposed.sql.Table +import org.jetbrains.exposed.sql.and +import org.jetbrains.exposed.sql.json.extract +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction + +object MessagesFilesTable : Table("xef_messages_files"), AssistantPersistence.MessageFile { + + val id = uuid("id") + val threadId = uuid("thread_id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun get(threadId: String, messageId: String, fileId: String): MessageFileObject { + return transaction { + MessagesFilesTable.select(data) + .where { + (data.extract("id") eq fileId) and + (MessagesFilesTable.threadId eq UUID.fromString(threadId)) and + (data.extract("messageId") eq messageId) + } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Message file not found for id: $fileId") + } + } + + override suspend fun list( + threadId: String, + messageId: String, + limit: Int?, + order: Assistants.OrderListMessageFiles?, + after: String?, + before: String? + ): ListMessageFilesResponse { + return transaction { + val query = + MessagesFilesTable.select(data).where { + (MessagesFilesTable.threadId eq UUID.fromString(threadId)) and + (data.extract("messageId") eq messageId) + } + val sortedQuery = + when (order) { + Assistants.OrderListMessageFiles.asc -> + query.orderBy(data.extract("createdAt") to SortOrder.ASC) + Assistants.OrderListMessageFiles.desc -> + query.orderBy(data.extract("createdAt") to SortOrder.DESC) + null -> query + } + val afterMessage = after?.let { UUID.fromString(it) } + val beforeMessage = before?.let { UUID.fromString(it) } + val afterMessageIndex = + afterMessage?.let { sortedQuery.indexOfFirst { it[data].id == afterMessage.toString() } } + val beforeMessageIndex = + beforeMessage?.let { sortedQuery.indexOfFirst { it[data].id == beforeMessage.toString() } } + val slicedQuery = + when { + afterMessageIndex != null -> sortedQuery.drop(afterMessageIndex + 1) + beforeMessageIndex != null -> sortedQuery.take(beforeMessageIndex) + else -> sortedQuery + } + val limitedQuery = limit?.let { slicedQuery.take(it) } ?: slicedQuery + ListMessageFilesResponse( + `object` = "list", + data = limitedQuery.map { it[data] }, + firstId = limitedQuery.firstOrNull()?.get(data)?.id, + lastId = limitedQuery.lastOrNull()?.get(data)?.id, + hasMore = sortedQuery.count() > limitedQuery.count() + ) + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/MessagesTable.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/MessagesTable.kt new file mode 100644 index 000000000..69d4d2e44 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/MessagesTable.kt @@ -0,0 +1,127 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.* +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.createMessageObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.modifiedMessageObject +import com.xebia.functional.xef.server.db.tables.format +import java.util.* +import kotlinx.serialization.json.JsonObject +import org.jetbrains.exposed.sql.* 
+import org.jetbrains.exposed.sql.json.extract +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction + +object MessagesTable : Table("xef_messages"), AssistantPersistence.Message { + val id = uuid("id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun list( + threadId: String, + limit: Int?, + order: Assistants.OrderListMessages?, + after: String?, + before: String? + ): ListMessagesResponse = transaction { + val query = MessagesTable.select(data).where { data.extract("threadId") eq threadId } + val sortedQuery = + when (order) { + Assistants.OrderListMessages.asc -> + query.orderBy(data.extract("createdAt") to SortOrder.ASC) + Assistants.OrderListMessages.desc -> + query.orderBy(data.extract("createdAt") to SortOrder.DESC) + null -> query + } + val afterMessage = after?.let { UUID.fromString(it) } + val beforeMessage = before?.let { UUID.fromString(it) } + val afterMessageIndex = + afterMessage?.let { sortedQuery.indexOfFirst { it[data].id == afterMessage.toString() } } + val beforeMessageIndex = + beforeMessage?.let { sortedQuery.indexOfFirst { it[data].id == beforeMessage.toString() } } + val slicedQuery = + when { + afterMessageIndex != null -> sortedQuery.drop(afterMessageIndex + 1) + beforeMessageIndex != null -> sortedQuery.take(beforeMessageIndex) + else -> sortedQuery + } + val limitedQuery = limit?.let { slicedQuery.take(it) } ?: slicedQuery + ListMessagesResponse( + `object` = "list", + data = limitedQuery.map { it[data] }, + firstId = limitedQuery.firstOrNull()?.get(data)?.id, + lastId = limitedQuery.lastOrNull()?.get(data)?.id, + hasMore = sortedQuery.count() > limitedQuery.count() + ) + } + + override suspend fun get(threadId: String, messageId: String): MessageObject = transaction { + MessagesTable.select(data) + .where { + (MessagesTable.id eq UUID.fromString(messageId)) and + (data.extract("threadId") eq threadId) + } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Message not found for id: $messageId") + } + + override suspend fun createMessage( + threadId: String, + assistantId: String, + runId: String, + content: String, + fileIds: List, + metadata: JsonObject?, + role: MessageObject.Role + ): MessageObject = transaction { + val uuid = UUID.randomUUID() + val msg = + createMessageObject( + kotlinx.uuid.UUID(uuid.toString()), + threadId, + role, + content, + assistantId, + runId, + fileIds, + metadata + ) + MessagesTable.insert { + it[id] = uuid + it[data] = msg + } + msg + } + + override suspend fun modify( + threadId: String, + messageId: String, + modifyMessageRequest: ModifyMessageRequest + ): MessageObject { + val messageObject = get(threadId, messageId) + return transaction { + val modifiedMessageObject = messageObject.copy(metadata = modifyMessageRequest.metadata) + MessagesTable.update({ MessagesTable.id eq UUID.fromString(messageId) }) { + it[data] = modifiedMessageObject + } + modifiedMessageObject + } + } + + override suspend fun updateContent( + threadId: String, + messageId: String, + content: String + ): MessageObject { + val messageObject = get(threadId, messageId) + return transaction { + val modifiedMessageObject = modifiedMessageObject(messageObject, content) + MessagesTable.update({ MessagesTable.id eq UUID.fromString(messageId) }) { + it[data] = modifiedMessageObject + } + modifiedMessageObject + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/RunsStepsTable.kt
b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/RunsStepsTable.kt new file mode 100644 index 000000000..c602d0836 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/RunsStepsTable.kt @@ -0,0 +1,129 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.* +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence +import com.xebia.functional.xef.llm.assistants.local.GeneralAssistants +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.runStepObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.updatedRunStepObject +import com.xebia.functional.xef.server.db.tables.format +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID +import org.jetbrains.exposed.sql.* +import org.jetbrains.exposed.sql.json.extract +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction + +object RunsStepsTable : Table("xef_runs"), AssistantPersistence.Step { + + val id = uuid("id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun get(threadId: String, runId: String, stepId: String): RunStepObject = + transaction { + RunsStepsTable.select(data) + .where { + (RunsStepsTable.id eq java.util.UUID.fromString(stepId)) and + (data.extract("runId") eq runId) and + (data.extract("threadId") eq threadId) + } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Run step not found for id: $stepId") + } + + override suspend fun list( + threadId: String, + runId: String, + limit: Int?, + order: Assistants.OrderListRunSteps?, + after: String?, + before: String? + ): ListRunStepsResponse { + return transaction { + val query = + RunsStepsTable.select(data).where { + (data.extract("runId") eq runId) and + (data.extract("threadId") eq threadId) + } + val sortedQuery = + when (order) { + Assistants.OrderListRunSteps.asc -> + query.orderBy(data.extract("createdAt") to SortOrder.ASC) + Assistants.OrderListRunSteps.desc -> + query.orderBy(data.extract("createdAt") to SortOrder.DESC) + null -> query + } + val afterMessage = after?.let { java.util.UUID.fromString(it) } + val beforeMessage = before?.let { java.util.UUID.fromString(it) } + val afterMessageIndex = + afterMessage?.let { sortedQuery.indexOfFirst { it[data].id == afterMessage.toString() } } + val beforeMessageIndex = + beforeMessage?.let { sortedQuery.indexOfFirst { it[data].id == beforeMessage.toString() } } + val slicedQuery = + when { + afterMessageIndex != null -> sortedQuery.drop(afterMessageIndex + 1) + beforeMessageIndex != null -> sortedQuery.take(beforeMessageIndex) + else -> sortedQuery + } + val limitedQuery = limit?.let { slicedQuery.take(it) } ?: slicedQuery + ListRunStepsResponse( + `object` = "list", + data = limitedQuery.map { it[data] }, + firstId = limitedQuery.firstOrNull()?.get(data)?.id, + lastId = limitedQuery.lastOrNull()?.get(data)?.id, + hasMore = sortedQuery.count() > limitedQuery.count() + ) + } + } + + override suspend fun create( + runObject: RunObject, + choice: GeneralAssistants.AssistantDecision, + toolCalls: List, + messageId: String? 
+ ): RunStepObject { + return transaction { + val stepId = UUID.generateUUID() + val runStepObject = runStepObject(stepId, runObject, choice, toolCalls, messageId) + RunsStepsTable.insert { + it[id] = java.util.UUID.fromString(stepId.toString()) + it[data] = runStepObject + } + runStepObject + } + } + + suspend fun updateStatus( + threadId: String, + runId: String, + stepId: String, + status: RunStepObject.Status + ): RunStepObject { + val runStepObject = get(threadId, runId, stepId) + return transaction { + val updatedRunStepObject = runStepObject.copy(status = status) + RunsStepsTable.update({ RunsStepsTable.id eq java.util.UUID.fromString(stepId) }) { + it[data] = updatedRunStepObject + } + updatedRunStepObject + } + } + + override suspend fun updateToolsStep( + runObject: RunObject, + stepId: String, + stepCalls: + List + ): RunStepObject { + val runStepObject = get(runObject.threadId, runObject.id, stepId) + return transaction { + val updatedRunStepObject = updatedRunStepObject(runStepObject, stepCalls) + RunsStepsTable.update({ RunsStepsTable.id eq java.util.UUID.fromString(stepId) }) { + it[data] = updatedRunStepObject + } + updatedRunStepObject + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/RunsTable.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/RunsTable.kt new file mode 100644 index 000000000..882c79fb4 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/RunsTable.kt @@ -0,0 +1,112 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.api.Assistants +import com.xebia.functional.openai.generated.model.CreateRunRequest +import com.xebia.functional.openai.generated.model.ListRunsResponse +import com.xebia.functional.openai.generated.model.ModifyRunRequest +import com.xebia.functional.openai.generated.model.RunObject +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence +import com.xebia.functional.xef.llm.assistants.local.GeneralAssistants +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.runObject +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.setRunToRequireToolOutouts +import com.xebia.functional.xef.server.db.tables.format +import kotlinx.uuid.UUID +import org.jetbrains.exposed.sql.SortOrder +import org.jetbrains.exposed.sql.Table +import org.jetbrains.exposed.sql.insert +import org.jetbrains.exposed.sql.json.extract +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction +import org.jetbrains.exposed.sql.update + +object RunsTable : Table("xef_runs"), AssistantPersistence.Run { + val id = uuid("id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun get(runId: String): RunObject = transaction { + RunsTable.select(data) + .where { RunsTable.id eq java.util.UUID.fromString(runId) } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Run not found for id: $runId") + } + + override suspend fun create(threadId: String, request: CreateRunRequest): RunObject { + val assistant = AssistantsTable.get(request.assistantId) + return transaction { + val uuid = java.util.UUID.randomUUID() + val runObject = runObject(UUID(uuid.toString()), threadId, request, assistant) + + RunsTable.insert { + it[id] = uuid + it[data] = runObject + } + runObject + } + } + + override suspend fun list( + threadId: String, + 
limit: Int?, + order: Assistants.OrderListRuns?, + after: String?, + before: String? + ): ListRunsResponse = transaction { + val query = RunsTable.select(data).where { data.extract("threadId") eq threadId } + val sortedQuery = + when (order) { + Assistants.OrderListRuns.asc -> + query.orderBy(data.extract("createdAt") to SortOrder.ASC) + Assistants.OrderListRuns.desc -> + query.orderBy(data.extract("createdAt") to SortOrder.DESC) + null -> query + } + val afterRun = after?.let { java.util.UUID.fromString(it) } + val beforeRun = before?.let { java.util.UUID.fromString(it) } + val afterRunIndex = + afterRun?.let { sortedQuery.indexOfFirst { it[data].id == afterRun.toString() } } + val beforeRunIndex = + beforeRun?.let { sortedQuery.indexOfFirst { it[data].id == beforeRun.toString() } } + val slicedQuery = + when { + afterRunIndex != null -> sortedQuery.drop(afterRunIndex + 1) + beforeRunIndex != null -> sortedQuery.take(beforeRunIndex) + else -> sortedQuery + } + val limitedQuery = limit?.let { slicedQuery.take(it) } ?: slicedQuery + ListRunsResponse( + `object` = "list", + data = limitedQuery.map { it[data] }, + firstId = limitedQuery.firstOrNull()?.get(data)?.id, + lastId = limitedQuery.lastOrNull()?.get(data)?.id, + hasMore = sortedQuery.count() > limitedQuery.count() + ) + } + + override suspend fun modify(runId: String, modifyRunRequest: ModifyRunRequest): RunObject { + val runObject = get(runId) + return transaction { + val modifiedRunObject = + runObject.copy(metadata = modifyRunRequest.metadata ?: runObject.metadata) + RunsTable.update({ RunsTable.id eq java.util.UUID.fromString(runId) }) { + it[data] = modifiedRunObject + } + modifiedRunObject + } + } + + override suspend fun updateRunToRequireToolOutputs( + id: String, + selectedTools: GeneralAssistants.SelectedTool + ): RunObject { + val runObject = get(id) + return transaction { + val modifiedRunObject = setRunToRequireToolOutouts(runObject, selectedTools) + RunsTable.update({ RunsTable.id eq java.util.UUID.fromString(id) }) { + it[data] = modifiedRunObject + } + modifiedRunObject + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/ThreadsTable.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/ThreadsTable.kt new file mode 100644 index 000000000..a0928b105 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/assistants/postgres/tables/ThreadsTable.kt @@ -0,0 +1,74 @@ +package com.xebia.functional.xef.server.assistants.postgres.tables + +import com.xebia.functional.openai.generated.model.CreateThreadRequest +import com.xebia.functional.openai.generated.model.ModifyThreadRequest +import com.xebia.functional.openai.generated.model.ThreadObject +import com.xebia.functional.xef.llm.assistants.local.AssistantPersistence +import com.xebia.functional.xef.server.assistants.utils.AssistantUtils.threadObject +import com.xebia.functional.xef.server.db.tables.format +import java.util.* +import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq +import org.jetbrains.exposed.sql.Table +import org.jetbrains.exposed.sql.deleteWhere +import org.jetbrains.exposed.sql.insert +import org.jetbrains.exposed.sql.json.jsonb +import org.jetbrains.exposed.sql.transactions.transaction +import org.jetbrains.exposed.sql.update + +object ThreadsTable : Table("xef_threads"), AssistantPersistence.Thread { + val id = uuid("id") + val data = jsonb("data", format) + + override val primaryKey = PrimaryKey(id) + + override suspend fun create( + assistantId:
String?, + runId: String?, + createThreadRequest: CreateThreadRequest + ): ThreadObject { + val uuid = UUID.randomUUID() + val threadObject = threadObject(kotlinx.uuid.UUID(uuid.toString()), createThreadRequest) + transaction { + ThreadsTable.insert { + it[id] = uuid + it[data] = threadObject + } + } + createThreadRequest.messages.orEmpty().forEach { + MessagesTable.createUserMessage( + threadId = threadObject.id, + assistantId = assistantId, + runId = runId, + createMessageRequest = it + ) + } + return threadObject + } + + override suspend fun get(threadId: String): ThreadObject { + return transaction { + ThreadsTable.select(data) + .where { ThreadsTable.id eq UUID.fromString(threadId) } + .singleOrNull() + ?.let { it[data] } ?: throw Exception("Thread not found for id: $threadId") + } + } + + override suspend fun delete(threadId: String): Boolean = transaction { + ThreadsTable.deleteWhere { id eq UUID.fromString(threadId) } > 0 + } + + override suspend fun modify( + threadId: String, + modifyThreadRequest: ModifyThreadRequest + ): ThreadObject { + val threadObject = get(threadId) + return transaction { + val modifiedThreadObject = threadObject.copy(metadata = modifyThreadRequest.metadata) + ThreadsTable.update({ ThreadsTable.id eq UUID.fromString(threadId) }) { + it[data] = modifiedThreadObject + } + modifiedThreadObject + } + } +} diff --git a/server/src/main/resources/db/migrations/psql/V2__Assistants.sql b/server/src/main/resources/db/migrations/psql/V2__Assistants.sql new file mode 100644 index 000000000..76f446e0c --- /dev/null +++ b/server/src/main/resources/db/migrations/psql/V2__Assistants.sql @@ -0,0 +1,30 @@ +CREATE TABLE IF NOT EXISTS xef_assistants( + id UUID PRIMARY KEY, + data JSONB NOT NULL +); + +CREATE TABLE IF NOT EXISTS xef_assistants_files( + id UUID PRIMARY KEY, + data JSONB NOT NULL +); + +CREATE TABLE IF NOT EXISTS xef_messages( + id UUID PRIMARY KEY, + data JSONB NOT NULL +); + +CREATE TABLE IF NOT EXISTS xef_messages_files( + id UUID PRIMARY KEY, + thread_id UUID NOT NULL, + data JSONB NOT NULL +); + +CREATE TABLE IF NOT EXISTS xef_runs( + id UUID PRIMARY KEY, + data JSONB NOT NULL +); + +CREATE TABLE IF NOT EXISTS xef_threads( + id UUID PRIMARY KEY, + data JSONB NOT NULL +); diff --git a/server/src/main/resources/logback.xml b/server/src/main/resources/logback.xml new file mode 100644 index 000000000..fb058d664 --- /dev/null +++ b/server/src/main/resources/logback.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n + + + + + + + + + + + + + + + diff --git a/server/web/src/components/Pages/Chat/Chat.tsx b/server/web/src/components/Pages/Chat/Chat.tsx index 4655f743d..174e45d90 100644 --- a/server/web/src/components/Pages/Chat/Chat.tsx +++ b/server/web/src/components/Pages/Chat/Chat.tsx @@ -38,10 +38,15 @@ export function Chat({initialMessages = []}: { setMessages(prevState => [...prevState, {role: 'user', content: prompt}]); const client = openai(settings); + //add header for json content type const completion = await client.chat.completions.create({ messages: [{role: 'user', content: prompt}], model: 'gpt-3.5-turbo-16k', stream: true + }, { + headers: { + 'Content-Type': 'application/json' + } }); let currentAssistantMessage = ''; // Create a local variable to accumulate the message diff --git a/server/web/src/utils/api/chatCompletions.ts b/server/web/src/utils/api/chatCompletions.ts index 1b5160660..cd1e04a55 100644 --- a/server/web/src/utils/api/chatCompletions.ts +++ 
b/server/web/src/utils/api/chatCompletions.ts @@ -1,6 +1,6 @@ -import { - defaultApiServer, -} from '@/utils/api'; +// import { +// defaultApiServer, +// } from '@/utils/api'; import {OpenAI} from "openai/index"; import {Settings} from "@/state/Settings"; @@ -8,7 +8,7 @@ import {Settings} from "@/state/Settings"; export function openai (settings: Settings): OpenAI { if (!settings.apiKey) throw 'API key not set'; return new OpenAI({ - baseURL: defaultApiServer, + //baseURL: defaultApiServer, // TODO: remove this when the key is the user token used for client auth dangerouslyAllowBrowser: true, apiKey: settings.apiKey, // defaults to process.env["OPENAI_API_KEY"] diff --git a/settings.gradle.kts b/settings.gradle.kts index dbb1da79e..95c386a62 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -63,6 +63,9 @@ include("xef-evaluator") project(":xef-evaluator").projectDir = file("evaluator") +include("xef-aws-bedrock") +project(":xef-aws-bedrock").projectDir = file("integrations/aws/bedrock") + // include("xef-server") project(":xef-server").projectDir = file("server")
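
Note on the patch: every list() implementation above uses the same in-memory cursor pagination: order the rows by the createdAt field inside the jsonb payload, scan for the after/before cursor id, slice with drop/take, apply the limit, and report hasMore by comparing counts. The standalone Kotlin sketch below illustrates that slicing logic in isolation; Item, paginate, and main are illustrative names only and are not part of this patch.

// Illustrative sketch of the cursor pagination used by the list() methods above (not part of the patch).
data class Item(val id: String, val createdAt: Long)

fun paginate(sorted: List<Item>, limit: Int?, after: String?, before: String?): Pair<List<Item>, Boolean> {
  // indexOfFirst returns -1 when the cursor id is not found; drop(-1 + 1) == drop(0),
  // so an unknown cursor falls back to the full list, matching the tables above.
  val afterIndex = after?.let { a -> sorted.indexOfFirst { it.id == a } }
  val beforeIndex = before?.let { b -> sorted.indexOfFirst { it.id == b } }
  val sliced = when {
    afterIndex != null -> sorted.drop(afterIndex + 1) // page starts after the cursor
    beforeIndex != null -> sorted.take(beforeIndex)   // page ends before the cursor
    else -> sorted
  }
  val page = limit?.let { sliced.take(it) } ?: sliced
  val hasMore = sorted.size > page.size // mirrors sortedQuery.count() > limitedQuery.count()
  return page to hasMore
}

fun main() {
  val rows = (1..5).map { Item(id = "id-$it", createdAt = it.toLong()) }.sortedBy { it.createdAt }
  val (page, hasMore) = paginate(rows, limit = 2, after = "id-2", before = null)
  println(page.map { it.id }) // [id-3, id-4]
  println(hasMore)            // true
}

Because the cursor lookup and slicing happen in application code rather than in SQL, each list() call materializes the full sorted result set; that matches the patch as written, but pushing the cursor filter and limit into the query may be worth considering for large tables.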