Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions bot/engine/src/main/kotlin/definition/AsyncStoryHandlerBase.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ package ai.tock.bot.definition

import ai.tock.bot.engine.AsyncBotBus
import ai.tock.bot.engine.AsyncBus
import ai.tock.bot.engine.Bot
import ai.tock.bot.engine.BotBus
import ai.tock.shared.InternalTockApi
import ai.tock.shared.coroutines.ExperimentalTockCoroutines
import ai.tock.shared.defaultNamespace
import ai.tock.translator.I18nKeyProvider.Companion.generateKey
import ai.tock.translator.I18nLabelValue
import ai.tock.translator.I18nLocalizedLabel
import kotlinx.coroutines.runBlocking
import mu.KotlinLogging

/**
Expand All @@ -41,8 +44,15 @@ abstract class AsyncStoryHandlerBase(
override var i18nNamespace: String = defaultNamespace
@InternalTockApi set

@Deprecated("Use coroutines to call this interface", replaceWith = ReplaceWith("handle(asyncBus)"))
override fun handle(bus: BotBus) {
runBlocking(
Bot.handlerCoroutineName { findStoryDefinition(bus)?.id ?: mainIntent?.name ?: "??" },
) { handle(AsyncBotBus(bus)) }
}

override suspend fun handle(bus: AsyncBus) {
val baseBus = (bus as AsyncBotBus).botBus
val baseBus = (bus as AsyncBotBus).syncBus
val storyDefinition = findStoryDefinition(bus)
// if not supported user interface, use unknown
if (storyDefinition?.unsupportedUserInterfaces?.contains(bus.userInterfaceType) == true) {
Expand All @@ -63,12 +73,18 @@ abstract class AsyncStoryHandlerBase(

protected abstract suspend fun action(bus: AsyncBus)

protected fun AsyncBus.isEndCalled() = StoryHandlerBase.isEndCalled((this as AsyncBotBus).botBus)
protected fun AsyncBus.isEndCalled() = StoryHandlerBase.isEndCalled((this as AsyncBotBus).syncBus)

/**
* Finds the story definition of this handler.
*/
open fun findStoryDefinition(bus: AsyncBus): StoryDefinition? = (bus as AsyncBotBus).botBus.botDefinition.findStoryByStoryHandler(this, bus.connectorId)
open fun findStoryDefinition(bus: AsyncBus): StoryDefinition? {
return findStoryDefinition((bus as AsyncBotBus).syncBus)
}

private fun findStoryDefinition(bus: BotBus): StoryDefinition? {
return bus.botDefinition.findStoryByStoryHandler(this, bus.connectorId)
}

/**
* Story i18n category.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
* limitations under the License.
*/

package ai.tock.bot.definition.definition
package ai.tock.bot.definition

import ai.tock.bot.definition.StoryHandler
import ai.tock.bot.definition.StoryHandlerListener
import ai.tock.bot.engine.AsyncBotBus
import ai.tock.bot.engine.AsyncBus
import ai.tock.bot.engine.BotBus
Expand Down
107 changes: 55 additions & 52 deletions bot/engine/src/main/kotlin/engine/AsyncBotBus.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import kotlin.coroutines.CoroutineContext
import kotlin.reflect.safeCast

@ExperimentalTockCoroutines
open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
open class AsyncBotBus(val syncBus: BotBus) : AsyncBus {
companion object {
/**
* Helper method to retrieve the current bus,
Expand All @@ -61,25 +61,28 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
suspend fun retrieveCurrentBus(): AsyncBotBus? = currentCoroutineContext()[Ref]?.bus
}

@Deprecated("Use syncBus instead", ReplaceWith("syncBus"))
val botBus get() = syncBus

private val executor: Executor get() = injector.provide()
private val featureDao: FeatureDAO get() = injector.provide()

override val connectorId: String
get() = botBus.connectorId
get() = syncBus.connectorId
override val targetConnectorType: ConnectorType
get() = botBus.targetConnectorType
get() = syncBus.targetConnectorType
override val botId: PlayerId
get() = botBus.botId
get() = syncBus.botId
override val userId: PlayerId
get() = botBus.userId
get() = syncBus.userId
override val userLocale: Locale
get() = botBus.userLocale
get() = syncBus.userLocale
override val userInterfaceType: UserInterfaceType
get() = botBus.userInterfaceType
get() = syncBus.userInterfaceType
override val intent: IntentAware?
get() = botBus.intent
get() = syncBus.intent
override val currentIntent: IntentAware?
get() = botBus.currentIntent
get() = syncBus.currentIntent
override val currentStoryDefinition: StoryDefinition
get() = story.definition
override var step: AsyncStoryStep<*>?
Expand All @@ -88,90 +91,90 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
story.step = step?.name
}
override val userInfo: UserPreferences
get() = botBus.userPreferences
get() = syncBus.userPreferences
override val userState: UserState
get() = botBus.userTimeline.userState
val story: Story get() = botBus.story
get() = syncBus.userTimeline.userState
val story: Story get() = syncBus.story

override fun defaultAnswerDelay() = botBus.defaultDelay(botBus.currentAnswerIndex)
override fun defaultAnswerDelay() = syncBus.defaultDelay(syncBus.currentAnswerIndex)

override suspend fun constrainNlp(nextActionState: NextUserActionState) {
botBus.nextUserActionState = nextActionState
syncBus.nextUserActionState = nextActionState
}

override fun choice(key: ParameterKey): String? {
return botBus.choice(key)
return syncBus.choice(key)
}

override fun booleanChoice(key: ParameterKey): Boolean {
return botBus.booleanChoice(key)
return syncBus.booleanChoice(key)
}

override fun hasActionEntity(role: String): Boolean {
return botBus.hasActionEntity(role)
return syncBus.hasActionEntity(role)
}

override fun <T : Value> entityValue(
role: String,
valueTransformer: (EntityValue) -> T?,
): T? {
return synchronized(botBus) { botBus.entityValue(role, valueTransformer) }
return synchronized(syncBus) { syncBus.entityValue(role, valueTransformer) }
}

override fun entityValueDetails(role: String): EntityValue? {
return synchronized(botBus) { botBus.entityValueDetails(role) }
return synchronized(syncBus) { syncBus.entityValueDetails(role) }
}

override fun changeEntityValue(
role: String,
newValue: EntityValue?,
) {
synchronized(botBus) { botBus.changeEntityValue(role, newValue) }
synchronized(syncBus) { syncBus.changeEntityValue(role, newValue) }
}

override fun changeEntityValue(
entity: Entity,
newValue: Value?,
) {
return synchronized(botBus) { botBus.changeEntityValue(entity, newValue) }
return synchronized(syncBus) { syncBus.changeEntityValue(entity, newValue) }
}

override fun removeAllEntityValues() {
// Synchronized to avoid ConcurrentModificationException with other entity setters
synchronized(botBus) {
botBus.removeAllEntityValues()
synchronized(syncBus) {
syncBus.removeAllEntityValues()
}
}

override fun <T : Any> getContextValue(key: DialogContextKey<T>): T? {
return botBus.dialog.state.context[key.name]?.let(key.type::safeCast)
return syncBus.dialog.state.context[key.name]?.let(key.type::safeCast)
}

override fun <T : Any> setContextValue(
key: DialogContextKey<T>,
value: T?,
) {
botBus.dialog.state.setContextValue(key, value)
syncBus.dialog.state.setContextValue(key, value)
}

override fun <T : Any> setBusContextValue(
key: DialogContextKey<T>,
value: T?,
) {
botBus.setBusContextValue(key.name, value)
syncBus.setBusContextValue(key.name, value)
}

override fun <T : Any> getBusContextValue(key: DialogContextKey<T>): T? {
return botBus.getBusContextValue<Any?>(key.name)?.let(key.type::safeCast)
return syncBus.getBusContextValue<Any?>(key.name)?.let(key.type::safeCast)
}

override suspend fun isFeatureEnabled(
feature: FeatureType,
default: Boolean,
): Boolean =
featureDao.isEnabled(
botBus.botDefinition.botId,
botBus.botDefinition.namespace,
syncBus.botDefinition.botId,
syncBus.botDefinition.namespace,
feature,
connectorId,
default,
Expand All @@ -183,21 +186,21 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
starterIntent: Intent,
step: StoryStepDef?,
) {
synchronized(botBus) {
botBus.stepDef = step
botBus.switchStory(storyDefinition, starterIntent)
botBus.hasCurrentSwitchStoryProcess = false
synchronized(syncBus) {
syncBus.stepDef = step
syncBus.switchStory(storyDefinition, starterIntent)
syncBus.hasCurrentSwitchStoryProcess = false
}
(storyDefinition.storyHandler as? AsyncStoryHandler)?.handle(this)
?: storyDefinition.storyHandler.handle(botBus)
?: storyDefinition.storyHandler.handle(syncBus)
}

override fun i18nWithKey(
key: String,
defaultLabel: String,
vararg args: Any?,
): I18nLabelValue {
return botBus.i18nKey(key, defaultLabel, *args)
return syncBus.i18nKey(key, defaultLabel, *args)
}

override fun i18nWithKey(
Expand All @@ -206,23 +209,23 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
defaultI18n: Set<I18nLocalizedLabel>,
vararg args: Any?,
): I18nLabelValue {
return botBus.i18nKey(key, defaultLabel, defaultI18n, *args)
return syncBus.i18nKey(key, defaultLabel, defaultI18n, *args)
}

override fun i18n(
defaultLabel: CharSequence,
args: List<Any?>,
): I18nLabelValue {
return botBus.i18n(defaultLabel, args)
return syncBus.i18n(defaultLabel, args)
}

override suspend fun send(
i18nText: CharSequence,
delay: Long,
) {
withContext(executor.asCoroutineDispatcher()) {
synchronized(botBus) {
botBus.send(i18nText, delay = delay)
synchronized(syncBus) {
syncBus.send(i18nText, delay = delay)
}
}
}
Expand All @@ -232,8 +235,8 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
vararg i18nArgs: Any?,
) {
withContext(executor.asCoroutineDispatcher()) {
synchronized(botBus) {
botBus.send(i18nText, *i18nArgs)
synchronized(syncBus) {
syncBus.send(i18nText, *i18nArgs)
}
}
}
Expand All @@ -243,8 +246,8 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
delay: Long,
) {
withContext(executor.asCoroutineDispatcher()) {
synchronized(botBus) {
botBus.end(i18nText, delay = delay)
synchronized(syncBus) {
syncBus.end(i18nText, delay = delay)
}
}
}
Expand All @@ -254,8 +257,8 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
vararg i18nArgs: Any?,
) {
withContext(executor.asCoroutineDispatcher()) {
synchronized(botBus) {
botBus.end(i18nText, *i18nArgs)
synchronized(syncBus) {
syncBus.end(i18nText, *i18nArgs)
}
}
}
Expand All @@ -265,11 +268,11 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
messageProvider: Bus<*>.() -> Any?,
) {
val messages = toMessageList(messageProvider)
synchronized(botBus) {
synchronized(syncBus) {
if (messages.messages.isEmpty()) {
botBus.send(delay)
syncBus.send(delay)
} else {
botBus.send(messages, delay)
syncBus.send(messages, delay)
}
}
}
Expand All @@ -279,11 +282,11 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
messageProvider: Bus<*>.() -> Any?,
) {
val messages = toMessageList(messageProvider)
synchronized(botBus) {
synchronized(syncBus) {
if (messages.messages.isEmpty()) {
botBus.end(delay)
syncBus.end(delay)
} else {
botBus.end(messages, delay)
syncBus.end(messages, delay)
}
}
}
Expand All @@ -292,7 +295,7 @@ open class AsyncBotBus(val botBus: BotBus) : AsyncBus {
// calls to `translate` are blocking (database and possibly translator API),
// so we ensure they are done in a worker thread
withContext(executor.asCoroutineDispatcher()) {
toMessageList(null, botBus, messageProvider)
toMessageList(null, syncBus, messageProvider)
}

data class Ref(val bus: AsyncBotBus) : CoroutineContext.Element {
Expand Down
16 changes: 15 additions & 1 deletion bot/engine/src/main/kotlin/engine/Bot.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ import ai.tock.bot.engine.feature.DefaultFeatureType
import ai.tock.bot.engine.nlp.NlpController
import ai.tock.bot.engine.user.UserTimeline
import ai.tock.shared.coroutines.ExperimentalTockCoroutines
import ai.tock.shared.devEnvironment
import ai.tock.shared.injector
import com.github.salomonbrys.kodein.instance
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.withContext
import mu.KotlinLogging
import java.util.Locale
import kotlin.coroutines.EmptyCoroutineContext

/**
*
Expand All @@ -54,6 +57,13 @@ internal class Bot(botDefinitionBase: BotDefinition, val configuration: BotAppli
* (warning: advanced usage only).
*/
internal fun retrieveCurrentBus(): BotBus? = currentBus.get()

internal inline fun handlerCoroutineName(storyId: () -> String) =
if (devEnvironment) {
CoroutineName("handler(${storyId()})")
} else {
EmptyCoroutineContext
}
}

private val logger = KotlinLogging.logger {}
Expand Down Expand Up @@ -127,7 +137,11 @@ internal class Bot(botDefinitionBase: BotDefinition, val configuration: BotAppli
val bus = TockBotBus(connector, userTimeline, dialog, action, connectorData, botDefinition)
val asyncBus = AsyncBotBus(bus)

withContext(AsyncBotBus.Ref(asyncBus) + currentBus.asContextElement(bus)) {
withContext(
AsyncBotBus.Ref(asyncBus) +
currentBus.asContextElement(bus) +
handlerCoroutineName(story.definition::id),
) {
val closeMessageQueue = bus.deferMessageSending(this)

if (asyncBus.isFeatureEnabled(DefaultFeatureType.DISABLE_BOT)) {
Expand Down
Loading
Loading