Make the file checksum thread pool size of the streaming state store configurable.

Summary of the change set below:
  * SQLConf: new internal conf spark.sql.streaming.stateStore.fileChecksumThreadPoolSize
    (int, must be >= 0, default 4) plus the accessor stateStoreFileChecksumThreadPoolSize.
  * ChecksumCheckpointFileManager: numThreads may be any non-negative value. With 0 no thread
    pool is created and the main/checksum file futures run on the calling thread.
  * HDFSBackedStateStoreProvider and RocksDB: take the pool size from the conf instead of the
    hard-coded 4, and log a warning when it is below 4 and file checksum is enabled.
  * StateStoreConf / RocksDBConf: carry the new value.
  * Tests for sequential mode, negative values, conf propagation and concurrent
    commit + maintenance.

NOTE(review): RocksDB.scala and StateStoreSuite.scala each carry a one-line review tweak, so the
post-image blob ids on their `index` lines are presumably stale - regenerate if they matter.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b2a2b7027394f..99bcc5b804c46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3684,6 +3684,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE =
+    buildConf("spark.sql.streaming.stateStore.fileChecksumThreadPoolSize")
+      .internal()
+      .doc("Number of threads used to read/write files and their corresponding checksum files " +
+        "concurrently. Set to 0 to disable the thread pool and run operations sequentially on " +
+        "the calling thread. WARNING: Reducing below the default value of 4 may have " +
+        "performance impact.")
+      .version("4.2.0")
+      .intConf
+      .checkValue(x => x >= 0, "Must be a non-negative integer (0 to disable thread pool)")
+      .createWithDefault(4)
+
   val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION =
     buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled")
       .internal()
@@ -7173,6 +7185,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def checkpointFileChecksumSkipCreationIfFileMissingChecksum: Boolean =
     getConf(STREAMING_CHECKPOINT_FILE_CHECKSUM_SKIP_CREATION_IF_FILE_MISSING_CHECKSUM)
 
+  def stateStoreFileChecksumThreadPoolSize: Int =
+    getConf(STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE)
+
   def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
 
   def useDeprecatedKafkaOffsetFetching: Boolean = getConf(USE_DEPRECATED_KAFKA_OFFSET_FETCHING)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
index fcfad636ab776..d5fbc563a3aba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
@@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 import java.util.zip.{CheckedInputStream, CheckedOutputStream, CRC32C}
 
-import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future}
 import scala.concurrent.duration.Duration
 import scala.io.Source
 
@@ -127,7 +127,9 @@ case class ChecksumFile(path: Path) {
  *                              orphan checksum files. If using this, it is your responsibility
  *                              to clean up the potential orphan checksum files.
  * @param numThreads This is the number of threads to use for the thread pool, for reading/writing
- *                   files. To avoid blocking, if the file manager instance is being used by a
+ *                   files. Must be a non-negative integer. Setting this to 0 disables the thread
+ *                   pool and runs all operations sequentially on the calling thread.
+ *                   To avoid blocking, if the file manager instance is being used by a
  *                   single thread, then you can set this to 2 (one thread for main file, another
  *                   for checksum file).
  *                   If file manager is shared by multiple threads, you can set it to
@@ -150,14 +152,26 @@ class ChecksumCheckpointFileManager(
     val numThreads: Int,
     val skipCreationIfFileMissingChecksum: Boolean)
   extends CheckpointFileManager with Logging {
-  assert(numThreads % 2 == 0, "numThreads must be a multiple of 2, we need 1 for the main file" +
-    "and another for the checksum file")
+  assert(numThreads >= 0, "numThreads must be a non-negative integer")
 
   import ChecksumCheckpointFileManager._
 
   // This allows us to concurrently read/write the main file and checksum file
-  private val threadPool = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonFixedThreadPool(numThreads, s"${this.getClass.getSimpleName}-Thread"))
+  private val threadPoolOpt: Option[ExecutionContextExecutorService] =
+    if (numThreads == 0) None
+    else Some(ExecutionContext.fromExecutorService(
+      ThreadUtils.newDaemonFixedThreadPool(numThreads, s"${this.getClass.getSimpleName}-Thread")))
+
+  // ExecutionContext used for I/O operations on ChecksumFSDataInputStream and
+  // ChecksumCancellableFSDataOutputStream: uses the thread pool when numThreads > 0, or
+  // runs operations synchronously on the calling thread when numThreads == 0.
+  private val executionContext: ExecutionContext = threadPoolOpt.getOrElse(
+    // This will execute the runnable synchronously on the calling thread
+    new ExecutionContext {
+      override def execute(runnable: Runnable): Unit = runnable.run()
+      override def reportFailure(cause: Throwable): Unit = throw cause
+    }
+  )
 
   override def list(path: Path, filter: PathFilter): Array[FileStatus] = {
     underlyingFileMgr.list(path, filter)
@@ -191,17 +205,17 @@ class ChecksumCheckpointFileManager(
 
     val mainFileFuture = Future {
       createFunc(path)
-    }(threadPool)
+    }(executionContext)
 
     val checksumFileFuture = Future {
       createFunc(getChecksumPath(path))
-    }(threadPool)
+    }(executionContext)
 
     new ChecksumCancellableFSDataOutputStream(
       awaitResult(mainFileFuture, Duration.Inf),
       path,
       awaitResult(checksumFileFuture, Duration.Inf),
-      threadPool
+      executionContext
     )
   }
 
@@ -219,17 +233,17 @@ class ChecksumCheckpointFileManager(
           log"hence no checksum verification.")
         None
       }
-    }(threadPool)
+    }(executionContext)
 
     val mainInputStreamFuture = Future {
       underlyingFileMgr.open(path)
-    }(threadPool)
+    }(executionContext)
 
     val mainStream = awaitResult(mainInputStreamFuture, Duration.Inf)
     val checksumStream = awaitResult(checksumInputStreamFuture, Duration.Inf)
 
     checksumStream.map { chkStream =>
-      new ChecksumFSDataInputStream(mainStream, path, chkStream, threadPool)
+      new ChecksumFSDataInputStream(mainStream, path, chkStream, executionContext)
     }.getOrElse(mainStream)
   }
 
@@ -249,11 +263,11 @@ class ChecksumCheckpointFileManager(
     // if it happens.
     val checksumInputStreamFuture = Future {
       deleteChecksumFile(getChecksumPath(path))
-    }(threadPool)
+    }(executionContext)
 
     val mainInputStreamFuture = Future {
       underlyingFileMgr.delete(path)
-    }(threadPool)
+    }(executionContext)
 
     awaitResult(mainInputStreamFuture, Duration.Inf)
     awaitResult(checksumInputStreamFuture, Duration.Inf)
@@ -279,18 +293,20 @@ class ChecksumCheckpointFileManager(
   }
 
   override def close(): Unit = {
-    threadPool.shutdown()
-    // Wait a bit for it to finish up in case there is any ongoing work
-    // Can consider making this timeout configurable, if needed
-    val timeoutMs = 500
-    if (!threadPool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) {
-      logWarning(log"Thread pool did not shutdown after ${MDC(TIMEOUT, timeoutMs)} ms," +
-        log" forcing shutdown")
-      threadPool.shutdownNow() // stop the executing tasks
-
-      // Wait a bit for the threads to respond
-      if (!threadPool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) {
-        logError(log"Thread pool did not terminate")
+    threadPoolOpt.foreach { pool =>
+      pool.shutdown()
+      // Wait a bit for it to finish up in case there is any ongoing work
+      // Can consider making this timeout configurable, if needed
+      val timeoutMs = 500
+      if (!pool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) {
+        logWarning(log"Thread pool did not shutdown after ${MDC(TIMEOUT, timeoutMs)} ms," +
+          log" forcing shutdown")
+        pool.shutdownNow() // stop the executing tasks
+
+        // Wait a bit for the threads to respond
+        if (!pool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) {
+          logError(log"Thread pool did not terminate")
+        }
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 3303b414ccd35..33539d0d74b40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -529,15 +529,19 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   private[state] lazy val fm = {
     val mgr = CheckpointFileManager.create(baseDir, hadoopConf)
     if (storeConf.checkpointFileChecksumEnabled) {
+      val threadPoolSize = storeConf.fileChecksumThreadPoolSize
+      if (threadPoolSize < 4) {
+        logWarning(s"fileChecksumThreadPoolSize is set to $threadPoolSize, which is below the " +
+          "recommended default of 4. This may have performance impact.")
+      }
       new ChecksumCheckpointFileManager(
         mgr,
         // Allowing this for perf, since we do orphan checksum file cleanup in maintenance anyway
         allowConcurrentDelete = true,
-        // We need 2 threads per fm caller to avoid blocking
-        // (one for main file and another for checksum file).
-        // Since this fm is used by both query task and maintenance thread,
-        // then we need 2 * 2 = 4 threads.
-        numThreads = 4,
+        // To avoid blocking, we need 2 threads per fm caller (one for main file, one for checksum
+        // file). Since this fm is used by both query task and maintenance thread, the recommended
+        // default is 2 * 2 = 4 threads. A value of 0 disables the thread pool (sequential mode).
+        numThreads = threadPoolSize,
         skipCreationIfFileMissingChecksum =
           storeConf.checkpointFileChecksumSkipCreationIfFileMissingChecksum)
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 69a7e9618bb3d..4b2967aee7d3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -152,11 +152,17 @@ class RocksDB(
 
   private val workingDir = createTempDir("workingDir")
 
-  // We need 2 threads per fm caller to avoid blocking
-  // (one for main file and another for checksum file).
-  // Since this fm is used by both query task and maintenance thread,
-  // then we need 2 * 2 = 4 threads.
-  protected val fileChecksumThreadPoolSize: Option[Int] = Some(4)
+  // To avoid blocking, we need 2 threads per fm caller (one for main file, one for checksum file).
+  // Since this fm is used by both query task and maintenance thread, the recommended default is
+  // 2 * 2 = 4 threads; 0 disables the pool. Only warn when file checksum is actually enabled.
+  protected val fileChecksumThreadPoolSize: Option[Int] = {
+    val size = conf.fileChecksumThreadPoolSize
+    if (conf.fileChecksumEnabled && size < 4) {
+      logWarning(s"fileChecksumThreadPoolSize is set to $size, which is below the " +
+        "recommended default of 4. This may have performance impact.")
+    }
+    Some(size)
+  }
 
   protected def createFileManager(
       dfsRootDir: String,
@@ -2404,6 +2410,7 @@ case class RocksDBConf(
     reportSnapshotUploadLag: Boolean,
     maxVersionsToDeletePerMaintenance: Int,
     fileChecksumEnabled: Boolean,
+    fileChecksumThreadPoolSize: Int,
     rowChecksumEnabled: Boolean,
     rowChecksumReadVerificationRatio: Long,
     mergeOperatorVersion: Int,
@@ -2619,6 +2626,7 @@ object RocksDBConf {
       storeConf.reportSnapshotUploadLag,
       storeConf.maxVersionsToDeletePerMaintenance,
       storeConf.checkpointFileChecksumEnabled,
+      storeConf.fileChecksumThreadPoolSize,
       storeConf.rowChecksumEnabled,
       storeConf.rowChecksumReadVerificationRatio,
       getPositiveIntConf(MERGE_OPERATOR_VERSION_CONF),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index f3bbc0ea24069..5a3e875541d02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -116,6 +116,9 @@ class StateStoreConf(
   /** Whether file checksum generation and verification is enabled. */
   val checkpointFileChecksumEnabled: Boolean = sqlConf.checkpointFileChecksumEnabled
 
+  /** Number of threads for the file checksum thread pool (0 to disable). */
+  val fileChecksumThreadPoolSize: Int = sqlConf.stateStoreFileChecksumThreadPoolSize
+
   /** whether to validate state schema during query run. */
   val stateSchemaCheckEnabled = sqlConf.isStateSchemaCheckEnabled
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ChecksumCheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ChecksumCheckpointFileManagerSuite.scala
index b15e8f167db58..f283cfeecf335 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ChecksumCheckpointFileManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ChecksumCheckpointFileManagerSuite.scala
@@ -250,6 +250,39 @@ abstract class ChecksumCheckpointFileManagerSuite extends CheckpointFileManagerT
       checksumFmWithoutFallback.close()
     }
   }
+
+  test("numThreads = 0 disables thread pool (sequential mode)") {
+    withTempHadoopPath { basePath =>
+      val fm = new ChecksumCheckpointFileManager(
+        createNoChecksumManager(basePath),
+        allowConcurrentDelete = true,
+        numThreads = 0,
+        skipCreationIfFileMissingChecksum = false)
+      val path = new Path(basePath, "testfile")
+      val checksumPath = getChecksumPath(path)
+      // Write a file (main + checksum) in sequential mode
+      fm.createAtomic(path, overwriteIfPossible = false).writeContent(42).close()
+      // Verify both the main file and checksum file were written to disk
+      assert(fm.exists(path), "Main file should exist after write")
+      assert(fm.exists(checksumPath), "Checksum file should exist after write")
+      // Read it back - readContent() closes the stream, which triggers checksum verification
+      assert(fm.open(path).readContent() == 42)
+      fm.close()
+    }
+  }
+
+  test("negative numThreads is invalid") {
+    withTempHadoopPath { basePath =>
+      val ex = intercept[AssertionError] {
+        new ChecksumCheckpointFileManager(
+          createNoChecksumManager(basePath),
+          allowConcurrentDelete = true,
+          numThreads = -1,
+          skipCreationIfFileMissingChecksum = false)
+      }
+      assert(ex.getMessage.contains("numThreads must be a non-negative integer"))
+    }
+  }
 }
 
 class FileContextChecksumCheckpointFileManagerSuite extends ChecksumCheckpointFileManagerSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index d35263b655d2f..cb5d5b0a651e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -2465,6 +2465,107 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     }
   }
 
+  test("fileChecksumThreadPoolSize propagates to ChecksumCheckpointFileManager") {
+    Seq(0, 1, 6).foreach { numThreads =>
+      val storeId = StateStoreId(newDir(), 0L, 0)
+      withSQLConf(
+        SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> "true",
+        SQLConf.STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE.key -> numThreads.toString) {
+        tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider =>
+          val fmMethod = PrivateMethod[CheckpointFileManager](Symbol("fm"))
+          val fm = provider match {
+            case hdfs: HDFSBackedStateStoreProvider =>
+              hdfs.fm
+            case rocksdb: RocksDBStateStoreProvider =>
+              rocksdb.rocksDB.fileManager invokePrivate fmMethod()
+            case _ =>
+              fail(s"Unexpected provider type: ${provider.getClass.getName}")
+          }
+          assert(fm.isInstanceOf[ChecksumCheckpointFileManager])
+          assert(fm.asInstanceOf[ChecksumCheckpointFileManager].numThreads === numThreads)
+        }
+      }
+    }
+  }
+
+  test("STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE: invalid negative value is rejected") {
+    val sqlConf = SQLConf.get.clone()
+    val ex = intercept[IllegalArgumentException] {
+      sqlConf.setConfString(SQLConf.STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE.key, "-1")
+    }
+    assert(ex.getMessage.contains("Must be a non-negative integer"))
+  }
+
+  test("fileChecksumThreadPoolSize = 0 supports sequential I/O (load, write, commit, reload)") {
+    val storeId = StateStoreId(newDir(), 0L, 0)
+    withSQLConf(
+      SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> "true",
+      SQLConf.STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE.key -> "0") {
+      // Write some state with sequential mode enabled
+      tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider =>
+        val store = provider.getStore(0)
+        put(store, "a", 0, 1)
+        put(store, "b", 0, 2)
+        store.commit()
+        provider.doMaintenance()
+      }
+
+      // Reload and verify state is intact
+      tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider =>
+        val store = provider.getStore(1)
+        assert(get(store, "a", 0) === Some(1))
+        assert(get(store, "b", 0) === Some(2))
+        store.abort()
+      }
+    }
+  }
+
+  test("fileChecksumThreadPoolSize = 0: concurrent store commit and maintenance both complete") {
+    val storeId = StateStoreId(newDir(), 0L, 0)
+    withSQLConf(
+      SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> "true",
+      SQLConf.STATE_STORE_FILE_CHECKSUM_THREAD_POOL_SIZE.key -> "0") {
+      tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider =>
+        // Build up a few versions so maintenance has something to work with
+        (0L until 3L).foreach { version =>
+          putAndCommitStore(provider, version, doMaintenance = false)
+        }
+
+        // Load the store and prepare the write before maintenance starts, so that
+        // store.commit() (the actual file I/O) is what overlaps with doMaintenance().
+        val store = provider.getStore(3)
+        put(store, "3", 3, 300)
+
+        val errors = new ConcurrentLinkedQueue[Throwable]()
+        val maintenanceStartedLatch = new CountDownLatch(1)
+        val maintenanceDoneLatch = new CountDownLatch(1)
+
+        val maintenanceThread = new Thread(() => {
+          try {
+            maintenanceStartedLatch.countDown()
+            provider.doMaintenance()
+          } catch {
+            case t: Throwable => errors.add(t)
+          } finally {
+            maintenanceDoneLatch.countDown()
+          }
+        })
+        maintenanceThread.setDaemon(true)
+        maintenanceThread.start()
+
+        // Wait until maintenance is running, then commit to simulate concurrency.
+        assert(maintenanceStartedLatch.await(30, TimeUnit.SECONDS),
+          "Maintenance thread did not start within 30 seconds")
+        store.commit()
+
+        assert(maintenanceDoneLatch.await(30, TimeUnit.SECONDS),
+          "Maintenance did not complete within 30 seconds")
+        assert(errors.isEmpty,
+          s"Maintenance failed with: ${errors.peek()}")
+      }
+    }
+  }
+
   private def verifyChecksumFiles(
       dir: String, expectedNumFiles: Int, expectedNumChecksumFiles: Int): Unit = {
     val allFiles = new File(dir)