core/build.gradle.kts — 4 changes: 3 additions & 1 deletion
@@ -106,7 +106,9 @@ kotlin {
}
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
implementation(libs.ollama.testcontainers)
implementation(libs.junit.jupiter.api)
implementation(libs.junit.jupiter.engine)
}
}
val linuxX64Main by getting {
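Note: the jvmTest source set gains the Testcontainers Ollama module plus the JUnit Jupiter API and engine, so JVM tests can spin up a local Ollama instance. A minimal sketch of such a test, assuming libs.ollama.testcontainers maps to org.testcontainers:ollama — the OllamaContainer class, image tag, and endpoint accessor below are assumptions, not part of this PR:

    import org.junit.jupiter.api.Test
    import org.testcontainers.ollama.OllamaContainer

    class OllamaSmokeTest {
      @Test
      fun `container starts and exposes an endpoint`() {
        // Image tag is a placeholder; pin whichever Ollama image the tests use.
        val ollama = OllamaContainer("ollama/ollama:0.1.26")
        ollama.start()
        try {
          // Point an OpenAI-compatible client at this URL instead of api.openai.com.
          println(ollama.endpoint)
        } finally {
          ollama.stop()
        }
      }
    }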
core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt — 4 changes: 3 additions & 1 deletion
@@ -6,6 +6,7 @@ import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestMo
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.ToolCallStrategy
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass
import kotlin.reflect.KType
@@ -108,10 +109,11 @@ sealed interface AI {
prompt: String,
target: KType = typeOf<A>(),
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
toolCallStrategy: ToolCallStrategy = ToolCallStrategy.Supported,
config: Config = Config(),
api: Chat = OpenAI(config).chat,
conversation: Conversation = Conversation()
): A = chat(Prompt(model, prompt), target, config, api, conversation)
): A = chat(Prompt(model, toolCallStrategy, prompt), target, config, api, conversation)

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
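Note: AI.invoke now accepts a toolCallStrategy (defaulting to ToolCallStrategy.Supported) and threads it into the Prompt it builds. A call-site sketch, assuming the companion invoke shown above; Recipe is a placeholder type and only the Supported case is visible in this diff:

    import com.xebia.functional.xef.AI
    import com.xebia.functional.xef.prompt.ToolCallStrategy
    import kotlinx.serialization.Serializable

    @Serializable
    data class Recipe(val name: String, val ingredients: List<String>)

    suspend fun cookieRecipe(): Recipe =
      AI(
        prompt = "Give me a recipe for chocolate chip cookies",
        // Override with a non-default strategy for providers without native tool calls.
        toolCallStrategy = ToolCallStrategy.Supported
      )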
core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt — 18 changes: 17 additions & 1 deletion
@@ -7,12 +7,16 @@ import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.llm.models.modelType
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptMessage
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.ToolCallStrategy
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.mapNotNull
import kotlinx.coroutines.flow.single
import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*

@@ -29,7 +33,18 @@ data class DefaultAI<A : Any>(
@Serializable data class Value<A>(val value: A)

private suspend fun <B> runWithSerializer(prompt: Prompt, serializer: KSerializer<B>): B =
api.prompt(prompt, conversation, serializer)
when (prompt.toolCallStrategy) {
ToolCallStrategy.Supported -> api.prompt(prompt, conversation, serializer)
else ->
runStreamingWithFunctionSerializer(prompt, serializer)
.mapNotNull {
when (it) {
is StreamedFunction.Property -> null
is StreamedFunction.Result -> it.value
}
}
.single()
}

private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
api.promptStreaming(prompt, conversation)
@@ -49,6 +64,7 @@ data class DefaultAI<A : Any>(
suspend operator fun invoke(prompt: Prompt): A {
val serializer = serializer()
return when (serializer.descriptor.kind) {
PrimitiveKind.STRING -> api.promptMessage(prompt, conversation) as A
SerialKind.ENUM -> {
runWithEnumSingleTokenSerializer(serializer, prompt)
}
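Note: DefaultAI.runWithSerializer now branches on the prompt's tool-call strategy: Supported keeps the direct api.prompt call, while any other strategy falls back to the streaming-function path and resolves the single StreamedFunction.Result; plain String targets short-circuit to api.promptMessage. A stand-alone illustration of the mapNotNull-then-single pattern, using placeholder types rather than xef's StreamedFunction:

    import kotlinx.coroutines.flow.*
    import kotlinx.coroutines.runBlocking

    sealed interface Streamed<out A> {
      data class Property(val name: String, val partial: String) : Streamed<Nothing>
      data class Result<A>(val value: A) : Streamed<A>
    }

    fun main() = runBlocking {
      val events: Flow<Streamed<Int>> =
        flowOf(Streamed.Property("answer", "4"), Streamed.Result(42))
      val answer =
        events
          .mapNotNull {
            when (it) {
              is Streamed.Property -> null   // drop partial property updates
              is Streamed.Result -> it.value // keep only the final value
            }
          }
          .single() // fails if the stream never produces a Result
      println(answer) // 42
    }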
@@ -7,6 +7,7 @@ import com.xebia.functional.openai.generated.model.CreateChatCompletionResponseC
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.PromptCalculator.adaptPromptToConversationAndModel
import com.xebia.functional.xef.llm.models.MessageWithUsage
import com.xebia.functional.xef.llm.models.MessagesUsage
import com.xebia.functional.xef.llm.models.MessagesWithUsage
@@ -18,7 +19,7 @@ import kotlinx.coroutines.flow.*
@AiDsl
fun Chat.promptStreaming(prompt: Prompt, scope: Conversation = Conversation()): Flow<String> =
flow {
val messagesForRequestPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope)
val messagesForRequestPrompt = prompt.adaptPromptToConversationAndModel(scope)

val request =
CreateChatCompletionRequest(
@@ -42,14 +43,13 @@ fun Chat.promptStreaming(prompt: Prompt, scope: Conversation = Conversation()):
}
content
}
.onEach { emit(it) }
.onCompletion {
val aiResponseMessage = PromptBuilder.assistant(buffer.toString())
val newMessages = prompt.messages + listOf(aiResponseMessage)
newMessages.addToMemory(scope, prompt.configuration.messagePolicy.addMessagesToConversation)
buffer.clear()
}
.collect()
.collect { emit(it) }
}

@AiDsl
@@ -88,7 +88,7 @@ private suspend fun <T> Chat.promptResponse(
): Pair<List<T>, CreateChatCompletionResponse> =
scope.metric.promptSpan(prompt) {
val promptMemories: List<Memory> = prompt.messages.toMemory(scope)
val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope)
val adaptedPrompt = prompt.adaptPromptToConversationAndModel(scope)

adaptedPrompt.addMetrics(scope)

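Note: the streaming path now re-emits upstream chunks with a single terminal collect { emit(it) } instead of onEach { emit(it) } followed by collect(), and adapts the prompt via the new extension-style call. The two shapes are behaviourally equivalent inside a flow builder; a minimal stand-alone sketch:

    import kotlinx.coroutines.flow.*
    import kotlinx.coroutines.runBlocking

    fun relay(upstream: Flow<String>): Flow<String> = flow {
      upstream
        .onCompletion { println("upstream finished") } // completion hook still runs once
        .collect { emit(it) }                          // re-emit each chunk into this flow
    }

    fun main() = runBlocking {
      relay(flowOf("Hel", "lo")).collect { print(it) } // prints "Hello", then the completion line
    }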
@@ -5,8 +5,11 @@ import arrow.core.raise.catch
import com.xebia.functional.openai.generated.api.Chat
import com.xebia.functional.openai.generated.model.*
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.llm.PromptCalculator.adaptPromptToConversationAndModel
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -20,14 +23,16 @@ import kotlinx.serialization.json.*
@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): FunctionObject {
val fnName = descriptor.serialName.substringAfterLast(".")
return chatFunction(fnName, buildJsonSchema(descriptor))
val description =
descriptor.annotations.firstOrNull { it is Description }?.let { it as Description }?.value
return chatFunction(fnName, description, buildJsonSchema(descriptor))
}

fun chatFunctions(descriptors: List<SerialDescriptor>): List<FunctionObject> =
descriptors.map(::chatFunction)

fun chatFunction(fnName: String, schema: JsonObject): FunctionObject =
FunctionObject(fnName, "Generated function for $fnName", schema)
fun chatFunction(fnName: String, description: String?, schema: JsonObject): FunctionObject =
FunctionObject(fnName, description ?: "Generated function for $fnName", schema)

@AiDsl
suspend fun <A> Chat.prompt(
@@ -36,7 +41,7 @@ suspend fun <A> Chat.prompt(
serializer: KSerializer<A>,
): A =
prompt(prompt, scope, chatFunctions(listOf(serializer.descriptor))) { call ->
Json.decodeFromString(serializer, call.arguments)
Config.DEFAULT.json.decodeFromString(serializer, call.arguments)
}

@OptIn(ExperimentalSerializationApi::class)
@@ -49,15 +54,16 @@ suspend fun <A> Chat.prompt(
): A =
prompt(prompt, scope, chatFunctions(descriptors)) { call ->
// adds a `type` field with the call.functionName serial name equivalent to the call arguments
val jsonWithDiscriminator = Json.decodeFromString(JsonElement.serializer(), call.arguments)
val jsonWithDiscriminator =
Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments)
val descriptor =
descriptors.firstOrNull { it.serialName.endsWith(call.functionName) }
?: error("No descriptor found for ${call.functionName}")
val newJson =
JsonObject(
jsonWithDiscriminator.jsonObject + ("type" to JsonPrimitive(descriptor.serialName))
)
Json.decodeFromString(serializer, Json.encodeToString(newJson))
Config.DEFAULT.json.decodeFromString(serializer, Config.DEFAULT.json.encodeToString(newJson))
}

@AiDsl
@@ -67,7 +73,7 @@ fun <A> Chat.promptStreaming(
serializer: KSerializer<A>,
): Flow<StreamedFunction<A>> =
promptStreaming(prompt, scope, chatFunction(serializer.descriptor)) { json ->
Json.decodeFromString(serializer, json)
Config.DEFAULT.json.decodeFromString(serializer, json)
}

@AiDsl
@@ -79,8 +85,7 @@ suspend fun <A> Chat.prompt(
): A =
scope.metric.promptSpan(prompt) {
val promptWithFunctions = prompt.copy(functions = functions)
val adaptedPrompt =
PromptCalculator.adaptPromptToConversationAndModel(promptWithFunctions, scope)
val adaptedPrompt = promptWithFunctions.adaptPromptToConversationAndModel(scope)
adaptedPrompt.addMetrics(scope)
val request = createChatCompletionRequest(adaptedPrompt)
tryDeserialize(serializer, promptWithFunctions.configuration.maxDeserializationAttempts) {
@@ -139,7 +144,7 @@ fun <A> Chat.promptStreaming(
serializer: (json: String) -> A,
): Flow<StreamedFunction<A>> = flow {
val promptWithFunctions = prompt.copy(functions = listOf(function))
val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(promptWithFunctions, scope)
val adaptedPrompt = promptWithFunctions.adaptPromptToConversationAndModel(scope)

val request = createChatCompletionRequest(adaptedPrompt).copy(stream = true)

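Note: chatFunction now pulls the function description from a @Description annotation on the target's serial descriptor (falling back to the generated text), and JSON (de)serialization goes through Config.DEFAULT.json instead of a bare Json instance. A sketch of how user code would drive the new lookup, assuming @Description is applicable at class level as the descriptor.annotations check implies:

    import com.xebia.functional.xef.conversation.Description
    import kotlinx.serialization.Serializable

    @Serializable
    @Description("A cooking recipe with its ingredients")
    data class Recipe(val name: String, val ingredients: List<String>)

    // Assumed effect of the change above:
    // chatFunction(Recipe.serializer().descriptor).description
    //   == "A cooking recipe with its ingredients"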
@@ -13,10 +13,10 @@ import com.xebia.functional.xef.store.Memory

internal object PromptCalculator {

suspend fun adaptPromptToConversationAndModel(prompt: Prompt, scope: Conversation): Prompt =
when (prompt.configuration.messagePolicy.addMessagesFromConversation) {
MessagesFromHistory.ALL -> adaptPromptFromConversation(prompt, scope)
MessagesFromHistory.NONE -> prompt
suspend fun Prompt.adaptPromptToConversationAndModel(scope: Conversation): Prompt =
when (configuration.messagePolicy.addMessagesFromConversation) {
MessagesFromHistory.ALL -> adaptPromptFromConversation(this, scope)
MessagesFromHistory.NONE -> this
}

private suspend fun adaptPromptFromConversation(prompt: Prompt, scope: Conversation): Prompt {
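Note: adaptPromptToConversationAndModel becomes a member extension on Prompt inside the PromptCalculator object, so callers import the member (as the files above do) and write prompt.adaptPromptToConversationAndModel(scope) instead of passing the prompt as an argument. A tiny stand-alone sketch of the same pattern with placeholder names:

    data class Msg(val text: String)

    object Calculator {
      fun Msg.adapt(scope: String): Msg = copy(text = "[$scope] $text")
    }

    fun main() {
      // Receiver-style call; the PR gets the same ergonomics by importing the member from PromptCalculator.
      val adapted = with(Calculator) { Msg("hello").adapt("conversation") }
      println(adapted) // Msg(text=[conversation] hello)
    }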