Skip to content

Commit a853eba

Browse files
Pazuzzuvsct-jburetPazuzzu
authored
resolves #1881 NLP: add healthcheck for Sagemaker classifiers (#1882)
* resolves #1881 NLP: add healthcheck for Sagemaker classifiers * resolves #1881 NLP: address reviews * resolves #1881 NLP: fix case, where no intialized sagemaker client returns OK * resolves #1881 NLP: address client buiding reviews --------- Co-authored-by: Julien Buret <jburet@voyages-sncf.com> Co-authored-by: Pazuzzu <bogaforlife@gmai.com>
1 parent 074eeb7 commit a853eba

File tree

16 files changed

+142
-8
lines changed

16 files changed

+142
-8
lines changed

nlp/api/service/src/main/kotlin/NlpVerticle.kt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package ai.tock.nlp.api
1818

19+
/**
20+
*
21+
*/
22+
// Add this import at the top with the other imports
1923
import ai.tock.nlp.front.client.FrontClient
2024
import ai.tock.nlp.front.service.UnknownApplicationException
2125
import ai.tock.nlp.front.shared.codec.ApplicationDump
@@ -28,6 +32,7 @@ import ai.tock.nlp.front.shared.merge.ValuesMergeQuery
2832
import ai.tock.nlp.front.shared.monitoring.MarkAsUnknownQuery
2933
import ai.tock.nlp.front.shared.monitoring.ParseRequestLogCountQuery
3034
import ai.tock.nlp.front.shared.parser.ParseQuery
35+
import ai.tock.nlp.model.service.NlpClassifierService
3136
import ai.tock.shared.Executor
3237
import ai.tock.shared.TOCK_FRONT_DATABASE
3338
import ai.tock.shared.TOCK_MODEL_DATABASE
@@ -46,9 +51,6 @@ import mu.KotlinLogging
4651
import org.litote.kmongo.Id
4752
import java.util.Locale
4853

49-
/**
50-
*
51-
*/
5254
class NlpVerticle : WebVerticle() {
5355

5456
private val protectPath = verticleBooleanProperty("tock_nlp_protect_path", false)
@@ -59,6 +61,9 @@ class NlpVerticle : WebVerticle() {
5961

6062
private val executor: Executor by injector.instance()
6163

64+
// Add this line to access NlpClassifierService
65+
private val nlpClassifierService = NlpClassifierService
66+
6267
override val logger: KLogger = KotlinLogging.logger {}
6368

6469
override fun authProvider(): TockAuthProvider? {
@@ -232,7 +237,7 @@ class NlpVerticle : WebVerticle() {
232237
Pair("duckling_service", { FrontClient.healthcheck() }),
233238
Pair("tock_front_database", { pingMongoDatabase(TOCK_FRONT_DATABASE) }),
234239
Pair("tock_model_database", { pingMongoDatabase(TOCK_MODEL_DATABASE) })
235-
)
240+
) + NlpClassifierService.healthcheck()
236241
)
237242
}
238243

nlp/core/service/src/main/kotlin/NlpCoreService.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package ai.tock.nlp.core.service
1818

19+
1920
import ai.tock.nlp.core.CallContext
2021
import ai.tock.nlp.core.Entity
2122
import ai.tock.nlp.core.EntityRecognition
@@ -40,8 +41,8 @@ import ai.tock.nlp.model.ModelNotInitializedException
4041
import ai.tock.nlp.model.NlpClassifier
4142
import ai.tock.shared.checkMaxLengthAllowed
4243
import ai.tock.shared.error
43-
import ai.tock.shared.normalize
4444
import ai.tock.shared.injector
45+
import ai.tock.shared.normalize
4546
import com.github.salomonbrys.kodein.instance
4647
import mu.KotlinLogging
4748

nlp/model/sagemaker/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
<groupId>software.amazon.awssdk</groupId>
3737
<artifactId>sagemakerruntime</artifactId>
3838
</dependency>
39+
<dependency>
40+
<groupId>software.amazon.awssdk</groupId>
41+
<artifactId>sagemaker</artifactId>
42+
</dependency>
3943
<dependency>
4044
<groupId>software.amazon.awssdk</groupId>
4145
<artifactId>sts</artifactId>

nlp/model/sagemaker/src/main/kotlin/SagemakerAwsClient.kt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,18 @@ package ai.tock.nlp.sagemaker
1717

1818
import ai.tock.shared.jackson.mapper
1919
import com.fasterxml.jackson.module.kotlin.readValue
20+
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider
2021
import software.amazon.awssdk.core.SdkBytes
22+
import software.amazon.awssdk.services.sagemaker.SageMakerClient
23+
import software.amazon.awssdk.services.sagemaker.model.DescribeEndpointRequest
2124
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient
2225
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest
2326
import java.nio.charset.Charset
2427

2528
class SagemakerAwsClient(private val configuration: SagemakerAwsClientProperties) {
2629

30+
val name = configuration.name
31+
2732
// for intentions and entities
2833
data class ParsedRequest(
2934
val text: String
@@ -56,6 +61,10 @@ class SagemakerAwsClient(private val configuration: SagemakerAwsClientProperties
5661
.region(configuration.region)
5762
.build()
5863

64+
private val sagemakerClient: SageMakerClient = SageMakerClient.builder()
65+
.region(configuration.region)
66+
.build()
67+
5968
fun parseIntent(request: ParsedRequest) = invokeSageMakerIntentEndpoint(request.text)
6069

6170
fun parseEntities(request: ParsedRequest): ParsedEntitiesResponse = invokeSageMakerEntitiesEndpoint(request.text)
@@ -84,4 +93,12 @@ class SagemakerAwsClient(private val configuration: SagemakerAwsClientProperties
8493
val entities = mapper.readValue<List<ParsedEntity>>(response.body().asInputStream())
8594
return ParsedEntitiesResponse(entities)
8695
}
96+
97+
fun healthcheck(): Boolean {
98+
val endpointRequest = DescribeEndpointRequest.builder()
99+
.endpointName(configuration.endpointName)
100+
.build()
101+
val response = sagemakerClient.describeEndpoint(endpointRequest)
102+
return response.endpointStatus() == software.amazon.awssdk.services.sagemaker.model.EndpointStatus.IN_SERVICE
103+
}
87104
}

nlp/model/sagemaker/src/main/kotlin/SagemakerAwsClientProperties.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import software.amazon.awssdk.regions.Region
2121
internal fun String.unescapeSagemakerName(): String = replace("___", ":")
2222

2323
data class SagemakerAwsClientProperties(
24+
val name: String,
2425
val region: Region,
2526
val endpointName: String,
2627
val contentType: String,

nlp/model/sagemaker/src/main/kotlin/SagemakerClientProvider.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,7 @@ internal object SagemakerClientProvider {
2222

2323
fun getClient(conf: SagemakerAwsClientProperties): SagemakerAwsClient =
2424
clientMap.getOrPut(conf) { SagemakerAwsClient(conf) }
25+
26+
fun getAllClient(): MutableCollection<SagemakerAwsClient> =
27+
clientMap.values
2528
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (C) 2017/2024 e-voyageurs technologies
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package ai.tock.nlp.sagemaker
17+
18+
/**
19+
* Enum representing the different types of Sagemaker clients.
20+
*/
21+
enum class SagemakerClientType(val clientName: String) {
22+
INTENT_CLASSIFICATION("intent-classification"),
23+
ENTITY_CLASSIFICATION("entity-classification");
24+
25+
override fun toString(): String = clientName
26+
}

nlp/model/sagemaker/src/main/kotlin/SagemakerEngineProvider.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package ai.tock.nlp.sagemaker
1818

19+
import NlpHealthcheckResult
1920
import ai.tock.nlp.core.NlpEngineType
2021
import ai.tock.nlp.model.TokenizerContext
2122
import ai.tock.nlp.model.service.engine.EntityClassifier
@@ -56,4 +57,24 @@ class SagemakerEngineProvider : NlpEngineProvider {
5657
// do not tokenize anything at this stage
5758
override fun tokenize(context: TokenizerContext, text: String): Array<String> = arrayOf(text)
5859
}
60+
61+
override fun healthcheck(): () -> NlpHealthcheckResult = {
62+
val clients = SagemakerClientProvider.getAllClient()
63+
if (clients.isEmpty()) {
64+
NlpHealthcheckResult(
65+
entityClassifier = false,
66+
intentClassifier = false
67+
)
68+
} else {
69+
val grouped = clients.groupBy { it.name }.withDefault { emptyList() }
70+
val entityClients = grouped.getValue(SagemakerEntityClassifier.CLIENT_TYPE.clientName)
71+
val intentClients = grouped.getValue(SagemakerIntentClassifier.CLIENT_TYPE.clientName)
72+
73+
NlpHealthcheckResult(
74+
entityClassifier = entityClients.isNotEmpty() && entityClients.all { it.healthcheck() },
75+
intentClassifier = intentClients.isNotEmpty() && intentClients.all { it.healthcheck() }
76+
)
77+
}
78+
}
5979
}
80+

nlp/model/sagemaker/src/main/kotlin/SagemakerEntityClassifier.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ import ai.tock.nlp.model.service.engine.NlpEntityClassifier
2626
import ai.tock.shared.property
2727
import software.amazon.awssdk.regions.Region
2828

29+
2930
internal class SagemakerEntityClassifier(model: EntityModelHolder) : NlpEntityClassifier(model) {
31+
companion object {
32+
val CLIENT_TYPE = SagemakerClientType.ENTITY_CLASSIFICATION
33+
}
3034

3135
override fun classifyEntities(
3236
context: EntityCallContext,
@@ -35,6 +39,7 @@ internal class SagemakerEntityClassifier(model: EntityModelHolder) : NlpEntityCl
3539
): List<EntityRecognition> {
3640
SagemakerClientProvider.getClient(
3741
SagemakerAwsClientProperties(
42+
CLIENT_TYPE.clientName,
3843
Region.of(property("tock_sagemaker_aws_region_name", "eu-west-3")),
3944
property("tock_sagemaker_aws_entities_endpoint_name", "default"),
4045
property("tock_sagemaker_aws_content_type", "application/json"),

nlp/model/sagemaker/src/main/kotlin/SagemakerIntentClassifier.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,24 @@
1515
*/
1616
package ai.tock.nlp.sagemaker
1717

18-
import ai.tock.nlp.sagemaker.SagemakerAwsClient.ParsedRequest
1918
import ai.tock.nlp.core.Intent
2019
import ai.tock.nlp.core.IntentClassification
2120
import ai.tock.nlp.model.IntentContext
2221
import ai.tock.nlp.model.service.engine.IntentClassifier
22+
import ai.tock.nlp.sagemaker.SagemakerAwsClient.ParsedRequest
2323
import ai.tock.shared.property
2424
import software.amazon.awssdk.regions.Region
2525

26+
2627
internal class SagemakerIntentClassifier(private val conf: SagemakerModelConfiguration) : IntentClassifier {
28+
companion object {
29+
val CLIENT_TYPE = SagemakerClientType.INTENT_CLASSIFICATION
30+
}
2731

2832
override fun classifyIntent(context: IntentContext, text: String, tokens: Array<String>): IntentClassification {
2933
return SagemakerClientProvider.getClient(
3034
SagemakerAwsClientProperties(
35+
CLIENT_TYPE.clientName,
3136
Region.of(property("tock_sagemaker_aws_region_name", "eu-west-3")),
3237
property("tock_sagemaker_aws_intent_endpoint_name", "default"),
3338
property("tock_sagemaker_aws_content_type", "application/json"),
@@ -36,7 +41,6 @@ internal class SagemakerIntentClassifier(private val conf: SagemakerModelConfigu
3641
).parseIntent(ParsedRequest(text))
3742
.run {
3843
object : IntentClassification {
39-
4044
var probability = 0.0
4145
val iterator = intent_ranking.iterator()
4246

0 commit comments

Comments
 (0)