Commit 1ac7a2b

Merge pull request #1 from scribd/dmytrou/AWS_secrets_manager
PE-939: AWS Secrets Manager
2 parents e58813b + 709d4b3 commit 1ac7a2b

File tree: 10 files changed (+206, −47 lines)


build.sbt

Lines changed: 2 additions & 1 deletion
```diff
@@ -12,7 +12,8 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion % Provided
 libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % Provided
 libraryDependencies += "org.apache.spark" %% "spark-hive" % sparkVersion % Provided
 libraryDependencies += "com.databricks" % "dbutils-api_2.12" % "0.0.5" % Provided
-libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595" % Provided
+libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595"
+libraryDependencies += "com.amazonaws" % "aws-java-sdk-secretsmanager" % "1.11.595"
 libraryDependencies += "io.delta" % "delta-core_2.12" % "1.0.0" % Provided
 libraryDependencies += "org.scalaj" %% "scalaj-http" % "2.4.2"
 //libraryDependencies += "org.apache.hive" % "hive-metastore" % "2.3.9"
```
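Note the scope change: `aws-java-sdk-s3` drops `% Provided` and the new `aws-java-sdk-secretsmanager` is likewise compile-scoped, so both SDKs now ship inside the built Overwatch artifact rather than being expected on the cluster classpath.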

src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala

Lines changed: 19 additions & 8 deletions
```diff
@@ -88,14 +88,25 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[OverwatchParams]) {
   override def deserialize(jp: JsonParser, ctxt: DeserializationContext): OverwatchParams = {
     val masterNode = jp.getCodec.readTree[JsonNode](jp)

-    val token = try {
-      Some(TokenSecret(
-        masterNode.get("tokenSecret").get("scope").asText(),
-        masterNode.get("tokenSecret").get("key").asText()))
-    } catch {
-      case e: Throwable =>
-        println("No Token Secret Defined", e)
-        None
+    // TODO: consider keeping enum with specific secrets inner structure and below
+    // transform to function processing the enum in a loop
+    val token = {
+
+      val databricksToken =
+        for {
+          scope <- getOptionString(masterNode, "tokenSecret.scope")
+          key <- getOptionString(masterNode, "tokenSecret.key")
+        } yield TokenSecret(scope, key)
+
+      val finalToken = if (databricksToken.isEmpty)
+        for {
+          secretId <- getOptionString(masterNode, "tokenSecret.secretId")
+          region <- getOptionString(masterNode, "tokenSecret.region")
+          apiToken <- getOptionString(masterNode, "tokenSecret.tokenKey")
+        } yield AwsTokenSecret(secretId, region, apiToken)
+      else databricksToken
+
+      finalToken
     }

     val rawAuditPath = getOptionString(masterNode, "auditLogConfig.rawAuditPath")
```
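For reference, these are the two `tokenSecret` payload shapes the deserializer now accepts, mirroring the new unit tests at the end of this diff (AWS Secrets Manager form first, then the original Databricks form):

```json
{"tokenSecret": {"secretId": "overwatch", "region": "us-east-2", "tokenKey": "apiToken"}}
{"tokenSecret": {"scope": "overwatch", "key": "test-key"}}
```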

src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala

Lines changed: 24 additions & 18 deletions
```diff
@@ -293,7 +293,7 @@ class Initializer(config: Config) extends SparkSessionWrapper {
     config.setExternalizeOptimize(rawParams.externalizeOptimize)

     val overwatchScope = rawParams.overwatchScope.getOrElse(Seq("all"))
-    val tokenSecret = rawParams.tokenSecret
+
     // TODO -- PRIORITY -- If data target is null -- default table gets dbfs:/null
     val dataTarget = rawParams.dataTarget.getOrElse(
       DataTarget(Some("overwatch"), Some("dbfs:/user/hive/warehouse/overwatch.db"), None))
@@ -302,24 +302,30 @@ class Initializer(config: Config) extends SparkSessionWrapper {
     if (overwatchScope.head == "all") config.setOverwatchScope(config.orderedOverwatchScope)
     else config.setOverwatchScope(validateScope(overwatchScope))

-    // validate token secret requirements
-    // TODO - Validate if token has access to necessary assets. Warn/Fail if not
-    if (tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
-      if (tokenSecret.get.scope.isEmpty || tokenSecret.get.key.isEmpty) {
-        throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
-          s"Either supply both or neither.")
+    if (rawParams.tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
+      rawParams.tokenSecret.map {
+        case databricksSecret: TokenSecret =>
+          // validate token secret requirements
+          // TODO - Validate if databricks token has access to necessary assets. Warn/Fail if not
+
+          if (databricksSecret.scope.isEmpty || databricksSecret.key.isEmpty) {
+            throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
+              s"Either supply both or neither.")
+          }
+          val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == databricksSecret.scope)
+          if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${databricksSecret.scope} does not exist " +
+            s"in this workspace. Please provide a scope available and accessible to this account.")
+          val scopeName = scopeCheck.head
+
+          val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == databricksSecret.key)
+          if (keyCheck.length == 0) throw new BadConfigException(s"Key ${databricksSecret.key} does not exist " +
+            s"within the provided scope: ${databricksSecret.scope}. Please provide a scope and key " +
+            s"available and accessible to this account.")
+
+          config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))
+
+        case awsSecret: AwsTokenSecret => config.registerWorkspaceMeta(Some(awsSecret))
       }
-      val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == tokenSecret.get.scope)
-      if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${tokenSecret.get.scope} does not exist " +
-        s"in this workspace. Please provide a scope available and accessible to this account.")
-      val scopeName = scopeCheck.head
-
-      val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == tokenSecret.get.key)
-      if (keyCheck.length == 0) throw new BadConfigException(s"Key ${tokenSecret.get.key} does not exist " +
-        s"within the provided scope: ${tokenSecret.get.scope}. Please provide a scope and key " +
-        s"available and accessible to this account.")
-
-      config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))
     } else config.registerWorkspaceMeta(None)

     // Validate data Target
```

src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala

Lines changed: 19 additions & 1 deletion
```diff
@@ -4,6 +4,7 @@ import com.databricks.labs.overwatch.pipeline.TransformFunctions._
 import com.databricks.labs.overwatch.utils._
 import org.apache.log4j.{Level, Logger}
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._

 class Module(
              val moduleId: Int,
@@ -172,10 +173,27 @@ class Module(
     initState
   }

+  private def normalizeToken(secretToken: TokenSecret, reportDf: DataFrame): DataFrame = {
+    val inputConfigCols = reportDf.select($"inputConfig.*")
+      .columns
+      .filter(_ != "tokenSecret")
+      .map(name => col("inputConfig." + name))
+
+    reportDf
+      .withColumn(
+        "inputConfig",
+        struct(inputConfigCols :+ struct(lit(secretToken.scope), lit(secretToken.key)).as("tokenSecret"): _*)
+      )
+  }
+
   private def finalizeModule(report: ModuleStatusReport): Unit = {
     pipeline.updateModuleState(report.simple)
     if (!pipeline.readOnly) {
-      pipeline.database.write(Seq(report).toDF, pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
+      val secretToken = SecretTools(report.inputConfig.tokenSecret.get).getTargetTableStruct
+      val targetDf = normalizeToken(secretToken, Seq(report).toDF)
+      pipeline.database.write(
+        targetDf,
+        pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
     }
   }
```
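The `normalizeToken` helper exists because the kryo encoders for token secrets are commented out in Structures.scala (below): `finalizeModule` instead writes `inputConfig.tokenSecret` to the `pipeline_report` table as a plain two-string struct, and `SecretTools.getTargetTableStruct` decides how each secret type is projected into it. A sketch of that mapping, with hypothetical values:

```scala
// Databricks secret: (scope, key) pass through unchanged.
SecretTools(TokenSecret("my-scope", "my-key")).getTargetTableStruct
// => TokenSecret("my-scope", "my-key")

// AWS secret: region lands in the scope slot, secretId in the key slot.
SecretTools(AwsTokenSecret("overwatch", "us-east-2")).getTargetTableStruct
// => TokenSecret("us-east-2", "overwatch")
```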

src/main/scala/com/databricks/labs/overwatch/utils/AwsSecrets.scala

Lines changed: 43 additions & 0 deletions

```diff
@@ -0,0 +1,43 @@
+package com.databricks.labs.overwatch.utils
+
+import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder
+import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest
+import org.apache.log4j.{Level, Logger}
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods.parse
+
+import java.util.Base64
+
+object AwsSecrets {
+  private val logger: Logger = Logger.getLogger(this.getClass)
+
+  def readApiToken(secretId: String, region: String, apiTokenKey: String = "apiToken"): String = {
+    secretValueAsMap(secretId, region)
+      .getOrElse(apiTokenKey, throw new IllegalStateException("apiTokenKey param not found"))
+      .asInstanceOf[String]
+  }
+
+  def secretValueAsMap(secretId: String, region: String = "us-east-2"): Map[String, Any] =
+    parseJsonToMap(readRawSecretFromAws(secretId, region))
+
+  def readRawSecretFromAws(secretId: String, region: String): String = {
+    logger.log(Level.INFO, s"Looking up secret $secretId in AWS Secret Manager")
+
+    val secretsClient = AWSSecretsManagerClientBuilder
+      .standard()
+      .withRegion(region)
+      .build()
+    val request = new GetSecretValueRequest().withSecretId(secretId)
+    val secretValue = secretsClient.getSecretValue(request)
+
+    if (secretValue.getSecretString != null)
+      secretValue.getSecretString
+    else
+      new String(Base64.getDecoder.decode(secretValue.getSecretBinary).array)
+  }
+
+  def parseJsonToMap(jsonStr: String): Map[String, Any] = {
+    implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
+    parse(jsonStr).extract[Map[String, Any]]
+  }
+}
```
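A minimal usage sketch, assuming a secret named `overwatch` in `us-east-2` whose value is a JSON document like `{"apiToken": "dapi..."}` (both names hypothetical):

```scala
// Reads the secret via AWS Secrets Manager, parses it as JSON, and returns
// the value under "apiToken" (the default apiTokenKey).
val token: String = AwsSecrets.readApiToken(secretId = "overwatch", region = "us-east-2")
```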

src/main/scala/com/databricks/labs/overwatch/utils/Config.scala

Lines changed: 3 additions & 9 deletions
```diff
@@ -321,20 +321,14 @@
    * as the job owner or notebook user (if called from notebook)
    * @return
    */
-  private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecret]): this.type = {
+  private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecretContainer]): this.type = {
     var rawToken = ""
-    var scope = ""
-    var key = ""
     try {
       // Token secrets not supported in local testing
       if (tokenSecret.nonEmpty && !_isLocalTesting) { // not local testing and secret passed
         _workspaceUrl = dbutils.notebook.getContext().apiUrl.get
         _cloudProvider = if (_workspaceUrl.toLowerCase().contains("azure")) "azure" else "aws"
-        scope = tokenSecret.get.scope
-        key = tokenSecret.get.key
-        rawToken = dbutils.secrets.get(scope, key)
-        val authMessage = s"Valid Secret Identified: Executing with token located in secret, $scope : $key"
-        logger.log(Level.INFO, authMessage)
+        rawToken = SecretTools(tokenSecret.get).getApiToken
         _tokenType = "Secret"
       } else {
         if (_isLocalTesting) { // Local testing env vars
@@ -353,7 +347,7 @@
       }
     }
     if (!rawToken.matches("^(dapi|dkea)[a-zA-Z0-9-]*$")) throw new BadConfigException(s"contents of secret " +
-      s"at scope:key $scope:$key is not in a valid format. Please validate the contents of your secret. It must be " +
+      s"is not in a valid format. Please validate the contents of your secret. It must be " +
       s"a user access token. It should start with 'dapi' ")
     setApiEnv(ApiEnv(isLocalTesting, workspaceURL, rawToken, packageVersion))
     this
```
src/main/scala/com/databricks/labs/overwatch/utils/SecretTools.scala

Lines changed: 57 additions & 0 deletions

```diff
@@ -0,0 +1,57 @@
+package com.databricks.labs.overwatch.utils
+
+import com.databricks.dbutils_v1.DBUtilsHolder.dbutils
+import org.apache.log4j.{Level, Logger}
+
+/**
+ * SecretTools handles common functionality related to working with secrets:
+ * 1) Get Databricks API token stored in specified secret
+ * 2) Normalize secret structure to be stored at Delta table pipeline_report under inputConfig.tokenSecret nested struct column
+ * There are two secret types available now - AWS Secrets Manager, Databricks secrets
+ */
+trait SecretTools[T <: TokenSecretContainer] {
+  def getApiToken: String
+  def getTargetTableStruct: TokenSecret
+}
+
+object SecretTools {
+  private val logger: Logger = Logger.getLogger(this.getClass)
+  type DatabricksTokenSecret = TokenSecret
+
+  private class DatabricksSecretTools(tokenSecret: DatabricksTokenSecret) extends SecretTools[DatabricksTokenSecret] {
+    override def getApiToken: String = {
+      val scope = tokenSecret.scope
+      val key = tokenSecret.key
+      val authMessage = s"Executing with Databricks token located in secret, scope=$scope : key=$key"
+      logger.log(Level.INFO, authMessage)
+      dbutils.secrets.get(scope, key)
+    }
+
+    override def getTargetTableStruct: TokenSecret = {
+      TokenSecret(tokenSecret.scope, tokenSecret.key)
+    }
+  }
+
+  private class AwsSecretTools(tokenSecret: AwsTokenSecret) extends SecretTools[AwsTokenSecret] {
+    override def getApiToken: String = {
+      val secretId = tokenSecret.secretId
+      val region = tokenSecret.region
+      val tokenKey = tokenSecret.tokenKey
+      val authMessage = s"Executing with AWS token located in secret, secretId=$secretId : region=$region : tokenKey=$tokenKey"
+      logger.log(Level.INFO, authMessage)
+      AwsSecrets.readApiToken(secretId, region, tokenSecret.tokenKey)
+    }
+
+    override def getTargetTableStruct: TokenSecret = {
+      TokenSecret(tokenSecret.region, tokenSecret.secretId)
+    }
+  }
+
+  def apply(secretSource: TokenSecretContainer): SecretTools[_] = {
+    secretSource match {
+      case x: AwsTokenSecret => new AwsSecretTools(x)
+      case y: DatabricksTokenSecret => new DatabricksSecretTools(y)
+      case _ => throw new IllegalArgumentException(s"${secretSource.toString} not implemented")
+    }
+  }
+}
```
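A short dispatch sketch with hypothetical values: `SecretTools.apply` selects the implementation from the runtime subtype of the configured secret, so callers like `Config.registerWorkspaceMeta` stay agnostic of the secret store.

```scala
// AwsTokenSecret routes to AwsSecretTools, TokenSecret to DatabricksSecretTools.
val awsApiToken = SecretTools(AwsTokenSecret("overwatch", "us-east-2")).getApiToken
val dbApiToken  = SecretTools(TokenSecret("overwatch", "test-key")).getApiToken
```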

src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala

Lines changed: 10 additions & 2 deletions
```diff
@@ -21,7 +21,9 @@ case class SparkDetail()

 case class GangliaDetail()

-case class TokenSecret(scope: String, key: String)
+abstract class TokenSecretContainer extends Product with Serializable
+case class TokenSecret(scope: String, key: String) extends TokenSecretContainer
+case class AwsTokenSecret(secretId: String, region: String, tokenKey: String = "apiToken") extends TokenSecretContainer

 case class DataTarget(databaseName: Option[String], databaseLocation: Option[String], etlDataPathPrefix: Option[String],
                       consumerDatabaseName: Option[String] = None, consumerDatabaseLocation: Option[String] = None)
@@ -75,7 +77,7 @@ case class AuditLogConfig(
 case class IntelligentScaling(enabled: Boolean = false, minimumCores: Int = 4, maximumCores: Int = 512, coeff: Double = 1.0)

 case class OverwatchParams(auditLogConfig: AuditLogConfig,
-                           tokenSecret: Option[TokenSecret] = None,
+                           tokenSecret: Option[TokenSecretContainer] = None,
                            dataTarget: Option[DataTarget] = None,
                            badRecordsPath: Option[String] = None,
                            overwatchScope: Option[Seq[String]] = None,
@@ -356,9 +358,15 @@ object OverwatchEncoders {
   implicit def overwatchScope: org.apache.spark.sql.Encoder[OverwatchScope] =
     org.apache.spark.sql.Encoders.kryo[OverwatchScope]

+  /*
   implicit def tokenSecret: org.apache.spark.sql.Encoder[TokenSecret] =
     org.apache.spark.sql.Encoders.kryo[TokenSecret]

+  implicit def tokenSecretContainer: org.apache.spark.sql.Encoder[TokenSecretContainer] =
+    org.apache.spark.sql.Encoders.kryo[TokenSecretContainer]
+
+  */
+
   implicit def dataTarget: org.apache.spark.sql.Encoder[DataTarget] =
     org.apache.spark.sql.Encoders.kryo[DataTarget]
```
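Since `tokenSecret` is now typed as `Option[TokenSecretContainer]`, existing configs built around Databricks `TokenSecret` keep compiling, and an `AwsTokenSecret` can be supplied in its place. A hedged construction sketch (an `auditLogConfig` value is assumed to be in scope; literals hypothetical):

```scala
// Either subtype satisfies the widened parameter.
val viaDatabricks = OverwatchParams(auditLogConfig, tokenSecret = Some(TokenSecret("overwatch", "test-key")))
val viaAws = OverwatchParams(auditLogConfig, tokenSecret = Some(AwsTokenSecret("overwatch", "us-east-2")))
// AwsTokenSecret.tokenKey defaults to "apiToken".
```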

src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@ import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import io.delta.tables.DeltaTable
 import org.apache.commons.lang3.StringEscapeUtils
 import org.apache.hadoop.conf._
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.log4j.{Level, Logger}
 import org.apache.spark.sql.functions._
 import org.apache.spark.util.SerializableConfiguration
```

src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala

Lines changed: 28 additions & 7 deletions
```diff
@@ -12,6 +12,34 @@ class ParamDeserializerTest extends AnyFunSpec {

   describe("ParamDeserializer") {

+    val paramModule: SimpleModule = new SimpleModule()
+      .addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
+    val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
+      .registerModule(DefaultScalaModule)
+      .registerModule(paramModule)
+      .asInstanceOf[ObjectMapper with ScalaObjectMapper]
+
+    it("should decode passed token string as AWS secrets") {
+      val AWSsecrets = """
+        |{"tokenSecret":{"secretId":"overwatch","region":"us-east-2","tokenKey":"apiToken"}}
+        |""".stripMargin
+
+      val expected = Some(AwsTokenSecret("overwatch", "us-east-2", "apiToken"))
+      val parsed = mapper.readValue[OverwatchParams](AWSsecrets)
+      assertResult(expected)(parsed.tokenSecret)
+    }
+
+    it("should decode passed token string as Databricks secrets") {
+      val Databrickssecrets = """
+        |{"tokenSecret":{"scope":"overwatch", "key":"test-key"}}
+        |""".stripMargin
+
+      val expected = Some(TokenSecret("overwatch", "test-key"))
+      val parsed = mapper.readValue[OverwatchParams](Databrickssecrets)
+      assertResult(expected)(parsed.tokenSecret)
+    }
+
     it("should decode incomplete parameters") {
       val incomplete = """
         |{"auditLogConfig":{"azureAuditLogEventhubConfig":{"connectionString":"test","eventHubName":"overwatch-evhub",
@@ -24,13 +52,6 @@ class ParamDeserializerTest extends AnyFunSpec {
         |"workspace_name":"myTestWorkspace", "externalizeOptimizations":"false"}
         |""".stripMargin

-      val paramModule: SimpleModule = new SimpleModule()
-        .addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
-      val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
-        .registerModule(DefaultScalaModule)
-        .registerModule(paramModule)
-        .asInstanceOf[ObjectMapper with ScalaObjectMapper]
-
       val expected = OverwatchParams(
         AuditLogConfig(
           azureAuditLogEventhubConfig = Some(AzureAuditLogEventhubConfig(
```
