Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ lazy val migrator = (project in file("migrator"))
"com.github.jnr" % "jnr-posix" % "3.1.19", // Needed by the Spark ScyllaDB connector
"com.scylladb.alternator" % "emr-dynamodb-hadoop" % "5.8.0",
"com.scylladb.alternator" % "load-balancing" % "1.0.0",
"com.mysql" % "mysql-connector-j" % "8.3.0",
"io.circe" %% "circe-generic" % circeVersion,
"io.circe" %% "circe-parser" % circeVersion,
"io.circe" %% "circe-yaml" % "0.15.1",
Expand All @@ -73,9 +74,9 @@ lazy val migrator = (project in file("migrator"))
case "mime.types" => MergeStrategy.first
case PathList("META-INF", "io.netty.versions.properties") => MergeStrategy.concat
case PathList("META-INF", "versions", _, "module-info.class") =>
MergeStrategy.discard // OK as long as we dont rely on Java 9+ features such as SPI
MergeStrategy.discard // OK as long as we don't rely on Java 9+ features such as SPI
case "module-info.class" =>
MergeStrategy.discard // OK as long as we dont rely on Java 9+ features such as SPI
MergeStrategy.discard // OK as long as we don't rely on Java 9+ features such as SPI
case x =>
val oldStrategy = (assembly / assemblyMergeStrategy).value
oldStrategy(x)
Expand Down
4 changes: 4 additions & 0 deletions migrator/src/main/scala/com/scylladb/migrator/Migrator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ object Migrator {
migratorConfig.getSkipTokenRangesOrEmptySet
)
ScyllaMigrator.migrate(migratorConfig, scyllaTarget, sourceDF)
case (mysqlSource: SourceSettings.MySQL, scyllaTarget: TargetSettings.Scylla) =>
log.info("Starting MySQL to ScyllaDB migration")
val sourceDF = readers.MySQL.readDataframe(spark, mysqlSource)
ScyllaMigrator.migrate(migratorConfig, scyllaTarget, sourceDF)
case (parquetSource: SourceSettings.Parquet, scyllaTarget: TargetSettings.Scylla) =>
readers.Parquet.migrateToScylla(migratorConfig, parquetSource, scyllaTarget)(spark)
case (cqlSource: SourceSettings.Cassandra, parquetTarget: TargetSettings.Parquet) =>
Expand Down
32 changes: 30 additions & 2 deletions migrator/src/main/scala/com/scylladb/migrator/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.scylladb.migrator.config.{ MigratorConfig, SourceSettings, TargetSett
import com.scylladb.migrator.validation.RowComparisonFailure
import org.apache.log4j.{ Level, LogManager, Logger }
import org.apache.spark.sql.SparkSession
import com.scylladb.migrator.scylla.ScyllaValidator
import com.scylladb.migrator.scylla.{ MySQLToScyllaValidator, ScyllaValidator }

object Validator {
val log = LogManager.getLogger("com.scylladb.migrator")
Expand All @@ -18,6 +18,8 @@ object Validator {
ScyllaValidator.runValidation(cassandraSource, scyllaTarget, config)
case (dynamoSource: SourceSettings.DynamoDB, alternatorTarget: TargetSettings.DynamoDB) =>
AlternatorValidator.runValidation(dynamoSource, alternatorTarget, config)
case (mysqlSource: SourceSettings.MySQL, scyllaTarget: TargetSettings.Scylla) =>
MySQLToScyllaValidator.runValidation(mysqlSource, scyllaTarget, config)
case _ =>
sys.error(
"Unsupported combination of source and target " +
Expand Down Expand Up @@ -49,7 +51,33 @@ object Validator {

if (failures.isEmpty) log.info("No comparison failures found - enjoy your day!")
else {
log.error("Found the following comparison failures:")
val missingCount = failures.count(_.items.exists {
case RowComparisonFailure.Item.MissingTargetRow => true
case _ => false
})
val differingCount = failures.count(_.items.exists {
case _: RowComparisonFailure.Item.DifferingFieldValues => true
case _ => false
})
val mismatchedColumnCount = failures.count(_.items.exists {
case RowComparisonFailure.Item.MismatchedColumnCount => true
case _ => false
})
val mismatchedColumnNames = failures.count(_.items.exists {
case RowComparisonFailure.Item.MismatchedColumnNames => true
case _ => false
})

val breakdown = List(
if (missingCount > 0) Some(s"$missingCount missing target row(s)") else None,
if (differingCount > 0) Some(s"$differingCount differing field value(s)") else None,
if (mismatchedColumnCount > 0) Some(s"$mismatchedColumnCount mismatched column count(s)")
else None,
if (mismatchedColumnNames > 0) Some(s"$mismatchedColumnNames mismatched column name(s)")
else None
).flatten.mkString(", ")

log.error(s"Found ${failures.size} comparison failure(s): $breakdown")
log.error(failures.mkString("\n"))
System.exit(1)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ object SourceSettings {
AwsUtils.computeFinalCredentials(credentials, endpoint, region)
}

/** Connection and read settings for a MySQL migration source.
  *
  * @param host                 MySQL server hostname.
  * @param port                 MySQL server port.
  * @param database             Database (schema) containing the table to read.
  * @param table                Name of the table to migrate.
  * @param credentials          Username/password used for the JDBC connection.
  * @param primaryKey           Optional primary-key column names. NOTE(review): not referenced by
  *                             the reader code visible here — presumably consumed by the
  *                             validator/writer; confirm against those call sites.
  * @param partitionColumn      Numeric column used to split the read into parallel JDBC
  *                             partitions; must be set together with `numPartitions`.
  * @param numPartitions        Number of parallel JDBC partitions; must be set together with
  *                             `partitionColumn`.
  * @param lowerBound           Lower bound of `partitionColumn` used to compute partition strides
  *                             (defaults to 0 when omitted).
  * @param upperBound           Upper bound of `partitionColumn` used to compute partition strides
  *                             (defaults to Long.MaxValue when omitted).
  * @param fetchSize            JDBC fetch size (rows per round trip with `useCursorFetch`).
  * @param where                Optional SQL predicate appended as a WHERE clause to the read query.
  * @param connectionProperties Extra JDBC properties passed through to the Spark JDBC reader;
  *                             a `maxAllowedPacket` entry here also overrides the URL default.
  */
case class MySQL(
host: String,
port: Int,
database: String,
table: String,
credentials: Credentials,
primaryKey: Option[List[String]],
partitionColumn: Option[String],
numPartitions: Option[Int],
lowerBound: Option[Long],
upperBound: Option[Long],
fetchSize: Int,
where: Option[String],
connectionProperties: Option[Map[String, String]]
) extends SourceSettings

case class DynamoDBS3Export(
bucket: String,
manifestKey: String,
Expand All @@ -72,7 +88,7 @@ object SourceSettings {

object DynamoDBS3Export {

/** Model the required fields of the TableCreationParameters object from the AWS API.
/** Model the required fields of the "TableCreationParameters" object from the AWS API.
* @see
* https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_TableCreationParameters.html
*/
Expand Down Expand Up @@ -147,6 +163,8 @@ object SourceSettings {
deriveDecoder[DynamoDB].apply(cursor)
case "dynamodb-s3-export" =>
deriveDecoder[DynamoDBS3Export].apply(cursor)
case "mysql" =>
deriveDecoder[MySQL].apply(cursor)
case otherwise =>
Left(DecodingFailure(s"Unknown source type: ${otherwise}", cursor.history))
}
Expand All @@ -173,5 +191,10 @@ object SourceSettings {
.encodeObject(s)
.add("type", Json.fromString("dynamodb-s3-export"))
.asJson
case s: MySQL =>
deriveEncoder[MySQL]
.encodeObject(s)
.add("type", Json.fromString("mysql"))
.asJson
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
package com.scylladb.migrator.config

import io.circe.{ Decoder, Encoder }
import io.circe.generic.semiauto.{ deriveDecoder, deriveEncoder }
import io.circe.generic.extras.Configuration
import io.circe.generic.extras.semiauto._

/** @param compareTimestamps
* Whether to compare TTL and WRITETIME metadata (Cassandra only).
* @param ttlToleranceMillis
* Tolerance for TTL comparisons.
* @param writetimeToleranceMillis
* Tolerance for WRITETIME comparisons.
* @param failuresToFetch
* Maximum number of row failures to collect before stopping.
* @param floatingPointTolerance
* Tolerance for floating-point value comparisons.
* @param timestampMsTolerance
* Tolerance in milliseconds for timestamp comparisons.
* @param hashColumns
* When set, these columns are replaced by a single MD5 hash on the MySQL side and computed via
* Spark on the ScyllaDB side. This dramatically reduces data transfer for large text/blob
* columns. Only applies to MySQL-to-ScyllaDB validation.
*/
case class Validation(
compareTimestamps: Boolean,
ttlToleranceMillis: Long,
writetimeToleranceMillis: Long,
failuresToFetch: Int,
floatingPointTolerance: Double,
timestampMsTolerance: Long
timestampMsTolerance: Long,
hashColumns: Option[List[String]] = None
)
object Validation {
implicit val encoder: Encoder[Validation] = deriveEncoder[Validation]
implicit val decoder: Decoder[Validation] = deriveDecoder[Validation]
implicit val config: Configuration = Configuration.default.withDefaults
implicit val encoder: Encoder[Validation] = deriveConfiguredEncoder[Validation]
implicit val decoder: Decoder[Validation] = deriveConfiguredDecoder[Validation]
}
147 changes: 147 additions & 0 deletions migrator/src/main/scala/com/scylladb/migrator/readers/MySQL.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package com.scylladb.migrator.readers

import com.scylladb.migrator.config.SourceSettings
import com.scylladb.migrator.scylla.SourceDataFrame
import org.apache.log4j.LogManager
import org.apache.spark.sql.{ DataFrame, SparkSession }

object MySQL {
  val log = LogManager.getLogger("com.scylladb.migrator.readers.MySQL")

  /** Default value for the JDBC `maxAllowedPacket` connection property: 256 MB. */
  val DefaultMaxAllowedPacketBytes: Long = 256 * 1024 * 1024 // 256MB

  /** Name of the synthetic column carrying the server-side MD5 hash when
    * hash-based reads are enabled (see [[readDataframeWithHash]]).
    */
  val ContentHashColumn: String = "_content_hash"

  /** Read the configured MySQL table and apply the ScyllaDB type mapping.
    *
    * @param spark  active Spark session
    * @param source MySQL connection and read settings
    * @return the mapped DataFrame; savepoints are not supported for JDBC sources
    *         (there is no token-range notion to resume from)
    */
  def readDataframe(spark: SparkSession, source: SourceSettings.MySQL): SourceDataFrame = {
    val df = readDataframeRaw(spark, source, hashColumns = None)
    val mapped = MySQLSchemaMapper.mapDataFrame(df)
    log.info("MySQL mapped schema (after type transformations):")
    mapped.printSchema()
    log.info(s"Number of partitions: ${mapped.rdd.getNumPartitions}")
    SourceDataFrame(mapped, timestampColumns = None, savepointsSupported = false)
  }

  /**
    * Read from MySQL with hash-based optimization for validation. The specified columns
    * are replaced by a single MD5 hash computed server-side in MySQL, dramatically reducing
    * network data transfer. MySQL still reads the columns from disk for hashing, but only
    * the 32-byte hash is sent over the wire.
    *
    * The resulting DataFrame contains all non-hashed columns plus a `_content_hash` column.
    *
    * @param hashColumns columns to fold into the single MD5 hash (case-insensitive match)
    */
  def readDataframeWithHash(spark: SparkSession,
                            source: SourceSettings.MySQL,
                            hashColumns: List[String]): DataFrame = {
    val df = readDataframeRaw(spark, source, hashColumns = Some(hashColumns))
    log.info("MySQL hash-based schema:")
    df.printSchema()
    log.info(s"Number of partitions: ${df.rdd.getNumPartitions}")
    df
  }

  /** Build the JDBC URL for the configured source.
    *
    * URL parameters:
    *   - zeroDateTimeBehavior=CONVERT_TO_NULL: map MySQL "zero" dates to NULL instead of failing
    *   - tinyInt1isBit=false: read TINYINT(1) as an integer rather than a boolean
    *   - useCursorFetch=true: stream rows with a server-side cursor so fetchSize is honored
    *   - maxAllowedPacket: user-supplied value wins over [[DefaultMaxAllowedPacketBytes]]
    */
  private def buildJdbcUrl(source: SourceSettings.MySQL): String = {
    val userProps = source.connectionProperties.getOrElse(Map.empty)
    val maxPacket = userProps.getOrElse("maxAllowedPacket", DefaultMaxAllowedPacketBytes.toString)
    s"jdbc:mysql://${source.host}:${source.port}/${source.database}" +
      s"?zeroDateTimeBehavior=CONVERT_TO_NULL&tinyInt1isBit=false" +
      s"&maxAllowedPacket=$maxPacket&useCursorFetch=true"
  }

  /**
    * Discover column names by reading zero rows from the MySQL table
    * (a `LIMIT 0` probe only fetches the result-set metadata).
    */
  private def discoverColumns(spark: SparkSession, source: SourceSettings.MySQL): Array[String] = {
    val url = buildJdbcUrl(source)
    val schema = spark.read
      .format("jdbc")
      .option("url", url)
      .option("user", source.credentials.username)
      .option("password", source.credentials.password)
      .option("driver", "com.mysql.cj.jdbc.Driver")
      .option("dbtable", s"(SELECT * FROM `${source.table}` LIMIT 0) AS schema_probe")
      .load()
      .schema
    schema.fieldNames
  }

  /** Core JDBC read. When `hashColumns` is set, the listed columns are replaced by a single
    * server-side MD5 hash column; otherwise the table (optionally filtered by `where`) is read
    * as-is. Partitioned reads require both `partitionColumn` and `numPartitions`.
    *
    * NOTE(review): `source.table` and `source.where` are interpolated into SQL. They come from
    * the operator-provided configuration file, not end users, so this is accepted — but they
    * must never be sourced from untrusted input.
    */
  private def readDataframeRaw(spark: SparkSession,
                               source: SourceSettings.MySQL,
                               hashColumns: Option[List[String]]): DataFrame = {
    val jdbcUrl = buildJdbcUrl(source)

    log.info(s"Connecting to MySQL at ${source.host}:${source.port}/${source.database}")
    log.info(s"Reading table: ${source.table}")
    log.info(s"JDBC useCursorFetch=true, fetchSize=${source.fetchSize}")

    val baseReader = spark.read
      .format("jdbc")
      .option("url", jdbcUrl)
      .option("user", source.credentials.username)
      .option("password", source.credentials.password)
      .option("driver", "com.mysql.cj.jdbc.Driver")
      .option("fetchsize", source.fetchSize)

    // Apply user-supplied JDBC properties; the dbtable option set below intentionally
    // takes precedence over any user-supplied "dbtable".
    val withUserProps = source.connectionProperties
      .getOrElse(Map.empty)
      .foldLeft(baseReader) { case (r, (k, v)) => r.option(k, v) }

    val tableExpression = hashColumns match {
      case Some(cols) if cols.nonEmpty =>
        val allColumnNames = discoverColumns(spark, source)
        // Warn about hash columns that do not exist in the table, so a typo does not
        // silently disable the optimization for that column.
        val unknown = cols.filterNot(c => allColumnNames.exists(_.equalsIgnoreCase(c)))
        if (unknown.nonEmpty)
          log.warn(
            s"hashColumns not present in table `${source.table}`: ${unknown.mkString(", ")}")
        val hashColSet = cols.map(_.toLowerCase).toSet
        val nonHashedCols = allColumnNames.filterNot(c => hashColSet.contains(c.toLowerCase))
        val hashExpr = buildMySQLHashExpression(cols)
        // When every column is hashed, the projection is just the hash expression;
        // prepending an empty column list would emit invalid SQL ("SELECT , MD5(...)").
        val projection =
          if (nonHashedCols.isEmpty) hashExpr
          else nonHashedCols.map(c => s"`$c`").mkString(", ") + s", $hashExpr"
        val whereClause = source.where.map(f => s" WHERE $f").getOrElse("")
        val subquery =
          s"(SELECT $projection FROM `${source.table}`$whereClause) AS hash_source"
        log.info(s"Using hash-based read. Hashed columns: ${cols.mkString(", ")}")
        log.info(s"Non-hashed columns: ${nonHashedCols.mkString(", ")}")
        log.info(s"Subquery: $subquery")
        subquery

      case _ =>
        source.where match {
          case Some(filter) =>
            log.info(s"Applying WHERE filter: $filter")
            s"(SELECT * FROM `${source.table}` WHERE $filter) AS filtered_table"
          case None => s"`${source.table}`"
        }
    }
    val reader = withUserProps.option("dbtable", tableExpression)

    // Partitioned reads need both the column and the partition count; failing fast on a
    // half-specified configuration beats silently falling back to a single partition.
    val partitionedReader = (source.partitionColumn, source.numPartitions) match {
      case (Some(col), Some(n)) =>
        log.info(
          s"Using partitioned read: column=$col, partitions=$n, " +
            s"lowerBound=${source.lowerBound.getOrElse(0L)}, " +
            s"upperBound=${source.upperBound.getOrElse(Long.MaxValue)}")
        reader
          .option("partitionColumn", col)
          .option("numPartitions", n)
          .option("lowerBound", source.lowerBound.getOrElse(0L))
          .option("upperBound", source.upperBound.getOrElse(Long.MaxValue))
      case (Some(col), None) =>
        sys.error(
          s"partitionColumn '$col' specified but numPartitions is missing. " +
            "Both partitionColumn and numPartitions must be set together for partitioned reads.")
      case (None, Some(n)) =>
        sys.error(
          s"numPartitions ($n) specified but partitionColumn is missing. " +
            "Both partitionColumn and numPartitions must be set together for partitioned reads.")
      case _ =>
        log.warn(
          "No partitioning configured. This will read the entire table in a single partition. " +
            "For large tables (2TB+), set partitionColumn and numPartitions for parallel reads.")
        reader
    }

    val rawDf = partitionedReader.load()
    log.info("MySQL raw source schema:")
    rawDf.printSchema()
    rawDf
  }

  /** Server-side MD5 over the given columns: NULLs coalesce to '' and values are joined
    * with '|' so the hash is stable across rows with NULL fields.
    */
  private def buildMySQLHashExpression(columns: List[String]): String = {
    val coalesced = columns.map(c => s"COALESCE(`$c`, '')").mkString(", ")
    s"MD5(CONCAT_WS('|', $coalesced)) AS `$ContentHashColumn`"
  }
}
Loading
Loading