diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 4c2d2a88..8897e7bb 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,6 +21,9 @@ jobs:
         os: [ubuntu-latest]
         java: [temurin@17]
     runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      pull-requests: write
     steps:
       - name: Checkout current branch (full)
         uses: actions/checkout@v4
@@ -64,6 +67,64 @@ jobs:
         working-directory: ./spark-ui
 
       - name: Build and test plugin
-        run: sbt +test
+        id: sbt-test
+        run: |
+          set -o pipefail
+          sbt +test 2>&1 | tee sbt-test.log
         working-directory: ./spark-plugin
 
+      - name: Surface failing test output
+        if: failure() && steps.sbt-test.conclusion == 'failure'
+        run: |
+          LOG=spark-plugin/sbt-test.log
+          {
+            echo "## sbt test failure"
+            echo ""
+            echo "### \`[error]\` / failed test lines"
+            echo '```'
+            # Grab failure markers, [error] lines and suite headers (last 200 matches, no extra context).
+            grep -nE '\*\*\* (FAILED|ABORTED) \*\*\*|^\[error\]|^\[info\] [A-Z][^:]*:$' "$LOG" \
+              | tail -200 || true
+            echo '```'
+            echo ""
+            echo "### Last 200 log lines"
+            echo '```'
+            tail -200 "$LOG"
+            echo '```'
+          } >> "$GITHUB_STEP_SUMMARY"
+
+      - name: Upload sbt test log
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: sbt-test-log
+          path: spark-plugin/sbt-test.log
+          if-no-files-found: ignore
+          retention-days: 14
+
+      - name: Post failure block as PR comment
+        if: failure() && steps.sbt-test.conclusion == 'failure' && github.event_name == 'pull_request'
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          PR_NUMBER: ${{ github.event.pull_request.number }}
+          SHA: ${{ github.event.pull_request.head.sha }}
+        run: |
+          LOG=spark-plugin/sbt-test.log
+          {
+            echo "## sbt test failure — \`${SHA:0:7}\`"
+            echo ""
+            echo "### \`[error]\` / failed test lines"
+            echo '```'
+            grep -nE '\*\*\* (FAILED|ABORTED) \*\*\*|^\[error\]|^\[info\] [A-Z][^:]*:$' "$LOG" \
+              | tail -200 || true
+            echo '```'
+            echo ""
+            echo "<details><summary>Last 300 lines of sbt output</summary>"
+            echo ""
+            echo '```'
+            tail -300 "$LOG"
+            echo '```'
+            echo "</details>"
+          } > /tmp/ci-failure-comment.md
+          gh pr comment "$PR_NUMBER" --body-file /tmp/ci-failure-comment.md
+
diff --git a/spark-plugin/build.sbt b/spark-plugin/build.sbt
index 5715a85a..2481a312 100644
--- a/spark-plugin/build.sbt
+++ b/spark-plugin/build.sbt
@@ -183,7 +183,8 @@ lazy val pluginspark4 = (project in file("pluginspark4"))
     Test / unmanagedSources ++= {
       val pluginspark3Tests = (pluginspark3 / Test / sourceDirectory).value / "scala"
       Seq(
-        pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "DataFlintCodegenFallbackSpec.scala"
+        pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "DataFlintCodegenFallbackSpec.scala",
+        pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "ZDataflintIOListenerSpec.scala"
       )
     },
 
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/DataflintSparkUICommonLoader.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/DataflintSparkUICommonLoader.scala
index 58cbdd23..867e2888 100644
--- a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/DataflintSparkUICommonLoader.scala
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/DataflintSparkUICommonLoader.scala
@@ -53,6 +53,10 @@ class DataflintSparkUICommonInstaller extends Logging {
     val deltaLakeCacheZindexFieldsToProperties = context.conf.getBoolean("spark.dataflint.instrument.deltalake.cacheZindexFieldsToProperties", defaultValue = true)
     val deltaLakeHistoryLimit = context.conf.getInt("spark.dataflint.instrument.deltalake.historyLimit", defaultValue = 1000)
     val icebergAuthCatalogDiscovery = context.conf.getBoolean("spark.dataflint.iceberg.autoCatalogDiscovery", defaultValue = false)
+    val ioTrackingEnabled = context.conf.getBoolean(DataflintSparkUICommonLoader.IO_TRACKING_ENABLED, defaultValue = false)
+    if (ioTrackingEnabled) {
+      DataflintSparkUICommonLoader.registerIOListener(context)
+    }
     if(icebergInstalled && icebergEnabled) {
       if(icebergAuthCatalogDiscovery && isMetricLoaderInRightClassLoader()) {
         context.conf.getAll.filter(_._1.startsWith("spark.sql.catalog")).filter(keyValue => keyValue._2 == "org.apache.iceberg.spark.SparkCatalog" || keyValue._2 == "org.apache.iceberg.spark.SparkSessionCatalog").foreach(keyValue => {
@@ -135,6 +139,8 @@ class DataflintSparkUICommonInstaller extends Logging {
 
 object DataflintSparkUICommonLoader extends Logging {
   private val DATAFLINT_EXTENSION_CLASS = "org.apache.spark.dataflint.DataFlintInstrumentationExtension"
+  private val DATAFLINT_IO_LISTENER_CLASS = "org.apache.spark.dataflint.listener.DataflintIOListener"
+  val IO_TRACKING_ENABLED = "spark.dataflint.io.tracking.enabled"
   val INSTRUMENT_SPARK_ENABLED = "spark.dataflint.instrument.spark.enabled"
   val INSTRUMENT_MAP_IN_PANDAS_ENABLED = "spark.dataflint.instrument.spark.mapInPandas.enabled"
   val INSTRUMENT_MAP_IN_ARROW_ENABLED = "spark.dataflint.instrument.spark.mapInArrow.enabled"
@@ -205,4 +211,32 @@ object DataflintSparkUICommonLoader extends Logging {
         logWarning("Could not register DataFlint instrumentation extension", e)
     }
   }
+
+  /**
+   * Append [[org.apache.spark.dataflint.listener.DataflintIOListener]] to
+   * `spark.sql.queryExecutionListeners` so Spark instantiates it when the
+   * SparkSession is built. The listener captures full read/write/save paths
+   * and table identifiers from the typed physical plan (bypassing
+   * SQLConf.maxToStringFields truncation of plan descriptions).
+   *
+   * This is in the org.apache.spark.dataflint package to access SparkContext.conf
+   * (which is private[spark]). Failures are logged and swallowed (best effort).
+   */
+  def registerIOListener(sc: SparkContext): Unit = {
+    try {
+      val current = sc.conf.get("spark.sql.queryExecutionListeners", "")
+      // Compare whole class names: a substring test would also match e.g. "...DataflintIOListenerV2".
+      val registered = current.split(",").map(_.trim).filter(_.nonEmpty)
+      if (registered.contains(DATAFLINT_IO_LISTENER_CLASS)) {
+        logInfo("DataflintIOListener already registered in spark.sql.queryExecutionListeners")
+      } else {
+        val updated = (registered :+ DATAFLINT_IO_LISTENER_CLASS).mkString(",")
+        sc.conf.set("spark.sql.queryExecutionListeners", updated)
+        logInfo("Registered DataflintIOListener in spark.sql.queryExecutionListeners")
+      }
+    } catch {
+      case e: Throwable =>
+        logWarning("Could not register DataflintIOListener", e)
+    }
+  }
 }
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintIOExtractor.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintIOExtractor.scala
new file mode 100644
index 00000000..fe4c6cb8
--- /dev/null
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintIOExtractor.scala
@@ -0,0 +1,204 @@
+package org.apache.spark.dataflint.listener
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
+import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution, RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommandExec}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, SaveIntoDataSourceCommand}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
+import org.apache.spark.sql.internal.SQLConf
+
+import scala.util.Try
+
+/**
+ * Pulls IO targets (reads/writes/saves) out of a QueryExecution by reading
+ * typed plan fields — not regex over `simpleString` — so paths and identifiers
+ * are captured at full length regardless of SQLConf.maxToStringFields truncation.
+ *
+ * Primary walk is over `qe.executedPlan` with a pre-order index assigned to each
+ * node; that index matches the nodeId Spark's [[org.apache.spark.sql.execution.ui.SparkPlanGraph]]
+ * assigns, so consumers can correlate each target back to the SparkUI SQL tab.
+ * For [[DataWritingCommandExec]] we delegate to the logical-plan matcher on the
+ * wrapped `cmd` and overlay the physical nodeId.
+ */
+object DataflintIOExtractor extends Logging {
+
+  /**
+   * Walk the physical plan, attaching the pre-order nodeId to every captured target.
+   * A single node can yield multiple targets (e.g. a FileSourceScanExec with multiple
+   * rootPaths emits one target per path) — keeps each target small enough that one
+   * Spark event per target stays well under Databricks' 2 MB per-event cap.
+   */
+  def extract(qe: QueryExecution): Seq[DataflintIOTarget] = {
+    val out = collection.mutable.ArrayBuffer.empty[DataflintIOTarget]
+    Try {
+      var idx = 0
+      qe.executedPlan.foreach { node =>
+        Try(extractFromPhysical(node, idx)).toOption.foreach(out ++= _)
+        idx += 1
+      }
+    }.recover {
+      case t: Throwable => logWarning("DataflintIOExtractor: failed to walk plan", t)
+    }
+    out.toSeq
+  }
+
+  private def extractFromPhysical(node: SparkPlan, nodeId: Int): Seq[DataflintIOTarget] = node match {
+
+    // --- PHYSICAL READS ------------------------------------------------------
+    case fsse: FileSourceScanExec =>
+      val rootPaths = fsse.relation.location.rootPaths.map(_.toString)
+      val base = DataflintIOTarget(
+        nodeId = nodeId,
+        operation = "read",
+        format = fsse.relation.fileFormat.getClass.getSimpleName,
+        paths = Seq.empty,
+        tableName = fsse.tableIdentifier.map(_.unquotedString),
+        saveMode = None,
+        partitionColumns = fsse.relation.partitionSchema.fieldNames.toSeq,
+        options = redactOptions(fsse.relation.options)
+      )
+      if (rootPaths.isEmpty) Seq(base)
+      else rootPaths.map(p => base.copy(paths = Seq(p)))
+
+    case bse: BatchScanExec =>
+      Seq(DataflintIOTarget(
+        nodeId = nodeId,
+        operation = "read",
+        format = Option(bse.table).map(_.getClass.getSimpleName).getOrElse("DataSourceV2"),
+        paths = Seq.empty,
+        tableName = Option(bse.table).map(_.name),
+        saveMode = None,
+        partitionColumns = Seq.empty,
+        options = Map.empty
+      ))
+
+    case rdse: RowDataSourceScanExec =>
+      Seq(DataflintIOTarget(
+        nodeId = nodeId,
+        operation = "read",
+        format = Try(rdse.relation.getClass.getSimpleName).getOrElse("RowDataSource"),
+        paths = Seq.empty,
+        tableName = Try(rdse.tableIdentifier.map(_.unquotedString)).getOrElse(None),
+        saveMode = None,
+        partitionColumns = Seq.empty,
+        options = Map.empty
+      ))
+
+    // --- PHYSICAL WRITES (V1) — delegate to logical matcher, overlay nodeId --
+    case dwce: DataWritingCommandExec =>
+      extractFromLogical(dwce.cmd).map(_.copy(nodeId = nodeId)).toSeq
+
+    case _ => Seq.empty
+  }
+
+  /**
+   * Logical-plan matcher for the commands [[DataWritingCommandExec]] wraps.
+   * Returns nodeId = -1; the physical-plan caller overlays the real id.
+   */
+  private def extractFromLogical(node: LogicalPlan): Option[DataflintIOTarget] = node match {
+
+    case lr: LogicalRelation =>
+      val (paths, format) = lr.relation match {
+        case h: HadoopFsRelation =>
+          (h.location.rootPaths.map(_.toString), h.fileFormat.getClass.getSimpleName)
+        case other =>
+          (Seq.empty[String], other.getClass.getSimpleName)
+      }
+      Some(DataflintIOTarget(
+        nodeId = -1,
+        operation = "read",
+        format = format,
+        paths = paths,
+        tableName = lr.catalogTable.map(_.identifier.unquotedString),
+        saveMode = None,
+        partitionColumns = lr.catalogTable.map(_.partitionColumnNames).getOrElse(Seq.empty),
+        options = Map.empty
+      ))
+
+    case d: DataSourceV2Relation =>
+      val opts = safeOptionsMap(d.options.asCaseSensitiveMap())
+      Some(DataflintIOTarget(
+        nodeId = -1,
+        operation = "read",
+        format = Option(d.table).map(_.getClass.getSimpleName).getOrElse("DataSourceV2"),
+        paths = optionalPath(opts),
+        tableName = d.identifier.map(_.toString).orElse(Option(d.table).map(_.name)),
+        saveMode = None,
+        partitionColumns = Seq.empty,
+        options = redactOptions(opts)
+      ))
+
+    case c: InsertIntoHadoopFsRelationCommand =>
+      Some(DataflintIOTarget(
+        nodeId = -1,
+        operation = "write",
+        format = c.fileFormat.getClass.getSimpleName,
+        paths = Seq(c.outputPath.toString),
+        tableName = c.catalogTable.map(_.identifier.unquotedString),
+        saveMode = Some(c.mode.toString),
+        partitionColumns = c.partitionColumns.map(_.name),
+        options = redactOptions(c.options)
+      ))
+
+    case c: SaveIntoDataSourceCommand =>
+      Some(DataflintIOTarget(
+        nodeId = -1,
+        operation = "write",
+        format = c.dataSource.getClass.getSimpleName,
+        paths = c.options.get("path").toSeq,
+        tableName = c.options.get("dbtable").orElse(c.options.get("table")),
+        saveMode = Some(c.mode.toString),
+        partitionColumns = Seq.empty,
+        options = redactOptions(c.options)
+      ))
+
+    case c: CreateDataSourceTableAsSelectCommand =>
+      Some(DataflintIOTarget(
+        nodeId = -1,
+        operation = "write",
+        format = c.table.provider.getOrElse(""),
+        paths = c.table.storage.locationUri.map(_.toString).toSeq,
+        tableName = Some(c.table.identifier.unquotedString),
+        saveMode = Some(c.mode.toString),
+        partitionColumns = c.table.partitionColumnNames,
+        options = redactOptions(c.table.storage.properties)
+      ))
+
+    case w: V2WriteCommand =>
+      val named = Try(w.table).toOption
+      Some(DataflintIOTarget(
+        nodeId = -1,
+        operation = "write",
+        format = named.map(_.getClass.getSimpleName).getOrElse("V2Write"),
+        paths = Seq.empty,
+        tableName = named.flatMap(t => Option(t.name)),
+        saveMode = Some(w.getClass.getSimpleName),
+        partitionColumns = Seq.empty,
+        options = Map.empty
+      ))
+
+    case _ => None
+  }
+
+  /** Defensive conversion of a Spark CaseInsensitiveStringMap to a Scala Map. */
+  private def safeOptionsMap(m: java.util.Map[String, String]): Map[String, String] = {
+    Try {
+      val it = m.entrySet().iterator()
+      val b = Map.newBuilder[String, String]
+      while (it.hasNext) { val e = it.next(); b += (e.getKey -> e.getValue) }
+      b.result()
+    }.getOrElse(Map.empty)
+  }
+
+  private def optionalPath(opts: Map[String, String]): Seq[String] =
+    opts.get("path").orElse(opts.get("paths")).toSeq
+
+  /**
+   * Mask sensitive option values (passwords, tokens, JDBC URLs, ...) with Spark's own
+   * redaction rules before they are posted to the listener bus and persisted in the
+   * event log. If redaction itself fails, drop the options rather than leak them.
+   */
+  private def redactOptions(opts: Map[String, String]): Map[String, String] =
+    Try(SQLConf.get.redactOptions(opts)).getOrElse(Map.empty)
+}
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintIOListener.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintIOListener.scala
new file mode 100644
index 00000000..1e5578d0
--- /dev/null
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintIOListener.scala
@@ -0,0 +1,102 @@
+package org.apache.spark.dataflint.listener
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+import scala.util.Try
+
+/**
+ * QueryExecutionListener that walks the executed physical plan after every action
+ * and posts a [[DataflintIOEvent]] with the full (untruncated) paths and table
+ * identifiers extracted from typed plan nodes.
+ *
+ * Registered automatically by [[org.apache.spark.dataflint.DataflintSparkUICommonLoader]]
+ * via `spark.sql.queryExecutionListeners` when `spark.dataflint.io.tracking.enabled=true`.
+ *
+ * Must have a no-arg constructor: Spark instantiates listeners listed in
+ * `spark.sql.queryExecutionListeners` via `Class.getConstructor()` reflection.
+ */
+class DataflintIOListener extends QueryExecutionListener with Logging {
+
+  /**
+   * Per-event size budget for the targets payload (rough upper bound of JSON bytes).
+   * Databricks drops any SparkListenerEvent larger than ~2 MB; we cap chunks at
+   * 1 MB so the JSON envelope + base fields still fit comfortably.
+   */
+  private val maxChunkBytes = 1024 * 1024
+
+  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+    process(funcName, qe, durationNs)
+  }
+
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+    // Still try — the plan was constructed even if execution failed.
+    process(funcName, qe, 0L)
+  }
+
+  private def process(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+    Try {
+      val targets = DataflintIOExtractor.extract(qe)
+      if (targets.nonEmpty) {
+        val durationMs = durationNs / 1000000L
+        val bus = qe.sparkSession.sparkContext.listenerBus
+        var chunkIndex = 0
+        chunkTargets(targets).foreach { chunk =>
+          val info = DataflintIOInfo(
+            executionId = qe.id,
+            sqlFuncName = funcName,
+            durationMs = durationMs,
+            chunkIndex = chunkIndex,
+            targets = chunk
+          )
+          bus.post(DataflintIOEvent(info))
+          chunkIndex += 1
+        }
+      }
+    }.recover {
+      case t: Throwable => logWarning("DataflintIOListener: failed to process query", t)
+    }
+  }
+
+  /**
+   * Greedy pack: accumulate targets in order until the next would push the chunk
+   * over the budget, then flush. A single target larger than the budget gets its
+   * own chunk (we prefer attempting a too-large event over dropping data).
+   */
+  private[listener] def chunkTargets(targets: Seq[DataflintIOTarget]): Seq[Seq[DataflintIOTarget]] = {
+    val out = collection.mutable.ArrayBuffer.empty[Seq[DataflintIOTarget]]
+    val current = collection.mutable.ArrayBuffer.empty[DataflintIOTarget]
+    var currentBytes = 0L
+    targets.foreach { t =>
+      val s = estimateBytes(t)
+      if (current.nonEmpty && currentBytes + s > maxChunkBytes) {
+        // toList — eager immutable copy. ArrayBuffer.toSeq in Scala 2.12 returns
+        // `this`, so a later current.clear() would empty the seq we just stored.
+        out += current.toList
+        current.clear()
+        currentBytes = 0L
+      }
+      current += t
+      currentBytes += s
+    }
+    if (current.nonEmpty) out += current.toList
+    out.toList
+  }
+
+  /** Rough upper bound for the JSON-serialized size of one target. */
+  private[listener] def estimateBytes(t: DataflintIOTarget): Int = {
+    val strBytes: Long =
+      t.format.length.toLong +
+      t.operation.length +
+      t.tableName.fold(0)(_.length) +
+      t.saveMode.fold(0)(_.length) +
+      t.paths.iterator.map(_.length.toLong).sum +
+      t.partitionColumns.iterator.map(_.length.toLong).sum +
+      t.options.iterator.map { case (k, v) => k.length.toLong + v.length }.sum
+    // 1.2x for JSON escaping/structure + 256 B per-target constant overhead.
+    // Accumulated as Long and clamped so a pathological target cannot overflow Int.
+    math.min(Int.MaxValue.toLong, strBytes * 12 / 10 + 256).toInt
+  }
+}
+
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintListener.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintListener.scala
index 158ea192..91018df5 100644
--- a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintListener.scala
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/DataflintListener.scala
@@ -26,6 +26,10 @@ class DataflintListener(store: ElementTrackingStore) extends SparkListener with
           val wrapper = new DataflintDeltaLakeScanInfoWrapper(e.scanInfo)
           store.write(wrapper)
         }
+        case e: DataflintIOEvent => {
+          val wrapper = new DataflintIOInfoWrapper(e.ioInfo)
+          store.write(wrapper)
+        }
         case _ => {}
       }
     } catch {
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/model.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/model.scala
index 16494aae..8ff21ef8 100644
--- a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/model.scala
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/listener/model.scala
@@ -91,3 +91,37 @@ class DataflintDeltaLakeScanInfoWrapper(val info: DataflintDeltaLakeScanInfo) {
   @JsonIgnore
   def id: String = s"${info.minExecutionId}_${info.tablePath.replaceAll(" ", "")}"
 }
+
+// Captured from the typed physical plan via QueryExecutionListener so paths/table
+// identifiers are the full untruncated values (not the simpleString form Spark uses
+// in plan descriptions, which gets cut at SQLConf.maxToStringFields).
+// nodeId is the pre-order index in qe.executedPlan, matching SparkPlanGraph node IDs;
+// -1 if extracted from a logical-plan node that has no direct physical correspondence.
+case class DataflintIOTarget(
+                              nodeId: Int,
+                              operation: String, // "read" | "write"
+                              format: String, // format/relation class simple name ("ParquetFileFormat", ...) or catalog provider
+                              paths: Seq[String],
+                              tableName: Option[String],
+                              saveMode: Option[String], // SaveMode.toString for V1 writes; command class name for V2 writes
+                              partitionColumns: Seq[String],
+                              options: Map[String, String]
+                            )
+
+case class DataflintIOInfo(
+                            executionId: Long,
+                            sqlFuncName: String,
+                            durationMs: Long,
+                            chunkIndex: Int,
+                            targets: Seq[DataflintIOTarget]
+                          )
+
+case class DataflintIOEvent(ioInfo: DataflintIOInfo) extends SparkListenerEvent
+
+class DataflintIOInfoWrapper(val info: DataflintIOInfo) {
+  // (executionId, chunkIndex) — one wrapper per emitted event. Chunks are sized
+  // by the listener to stay under Databricks' 2 MB per-event cap.
+  @KVIndex
+  @JsonIgnore
+  def id: String = s"${info.executionId}_${info.chunkIndex}"
+}
diff --git a/spark-plugin/plugin/src/test/scala/org/apache/spark/dataflint/listener/DataflintIOListenerChunkingSpec.scala b/spark-plugin/plugin/src/test/scala/org/apache/spark/dataflint/listener/DataflintIOListenerChunkingSpec.scala
new file mode 100644
index 00000000..9d96fa8a
--- /dev/null
+++ b/spark-plugin/plugin/src/test/scala/org/apache/spark/dataflint/listener/DataflintIOListenerChunkingSpec.scala
@@ -0,0 +1,67 @@
+package org.apache.spark.dataflint.listener
+
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+/**
+ * Pure unit test for the chunking logic — no SparkSession needed. Builds
+ * targets of known string size and asserts that DataflintIOListener.chunkTargets
+ * never drops a target, never produces an empty chunk, and respects the size
+ * budget within one over-the-budget item's worth of slack.
+ */
+class DataflintIOListenerChunkingSpec extends AnyFunSuite with Matchers {
+
+  private val listener = new DataflintIOListener
+
+  private def target(pathSize: Int, nodeId: Int = 0): DataflintIOTarget =
+    DataflintIOTarget(
+      nodeId = nodeId,
+      operation = "read",
+      format = "Parquet",
+      paths = Seq("a" * pathSize),
+      tableName = None,
+      saveMode = None,
+      partitionColumns = Seq.empty,
+      options = Map.empty
+    )
+
+  test("packs many small targets into a single chunk") {
+    val targets = (1 to 50).map(_ => target(64))
+    val chunks = listener.chunkTargets(targets)
+    chunks should have size 1
+    chunks.head should have size 50
+  }
+
+  test("splits when total estimated bytes exceed the budget") {
+    // 600 KB of path per target is ~720 KB estimated (1.2x), so only one fits per 1 MB chunk.
+    val targets = (1 to 10).map(_ => target(600 * 1024))
+    val chunks = listener.chunkTargets(targets)
+    chunks.size should be > 1
+    chunks.flatten should contain theSameElementsInOrderAs targets
+    // Each chunk respects the budget except possibly a single oversize item alone.
+    chunks.foreach { chunk =>
+      val sum = chunk.map(listener.estimateBytes).sum
+      withClue(s"chunk size $sum, chunk len ${chunk.size}: ") {
+        (chunk.size == 1 || sum <= 1024 * 1024) shouldBe true
+      }
+    }
+  }
+
+  test("an oversized target gets its own chunk rather than being dropped") {
+    val huge = target(3 * 1024 * 1024) // ~3 MB single target — bigger than budget
+    val chunks = listener.chunkTargets(Seq(huge))
+    chunks should have size 1
+    chunks.head should have size 1
+  }
+
+  test("preserves order of targets across chunks") {
+    val targets = (1 to 20).map(i => target(200 * 1024, nodeId = i))
+    val chunks = listener.chunkTargets(targets)
+    val nodeIds = chunks.flatten.map(_.nodeId)
+    nodeIds should contain theSameElementsInOrderAs (1 to 20)
+  }
+
+  test("empty input produces empty output") {
+    listener.chunkTargets(Seq.empty) shouldBe empty
+  }
+}
diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala
index ee08da4a..4b68f829 100644
--- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala
+++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala
@@ -71,6 +71,10 @@ class DataFlintWriteMetricsSpec extends AnyFunSuite with Matchers with BeforeAnd
     Seq(("a", 1), ("b", 2), ("c", 3), ("d", 4)).toDF("key", "value")
       .write.mode("overwrite").parquet(tempDir.getAbsolutePath)
 
+    // QueryExecutionListener delivery is async via the listener bus — drain
+    // it so the test's onSuccess hook has set capturedQE before we read it.
+    spark.sparkContext.listenerBus.waitUntilEmpty(10000)
+
     val qe = capturedQE.get()
     qe should not be null
 
diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/ZDataflintIOListenerSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/ZDataflintIOListenerSpec.scala
new file mode 100644
index 00000000..efba1197
--- /dev/null
+++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/ZDataflintIOListenerSpec.scala
@@ -0,0 +1,108 @@
+package org.apache.spark.dataflint
+
+import java.util.concurrent.CopyOnWriteArrayList
+
+import org.apache.spark.dataflint.listener.{DataflintIOEvent, DataflintIOInfo, DataflintIOListener, DataflintIOTarget}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+/**
+ * Integration test: enable DataflintIOListener via spark.sql.queryExecutionListeners,
+ * run a write+read against a deeply-nested temp path, and verify that the listener
+ * captured the FULL path — i.e. nothing got cut by Spark's maxToStringFields-driven
+ * plan description truncation.
+ */
+class ZDataflintIOListenerSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll {
+
+  private var spark: SparkSession = _
+  private val captured = new CopyOnWriteArrayList[DataflintIOInfo]()
+
+  private val collector = new SparkListener {
+    override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+      case e: DataflintIOEvent => captured.add(e.ioInfo)
+      case _ =>
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    spark = SparkSession.builder()
+      .master("local[1]")
+      .appName("DataflintIOListenerSpec")
+      // Drive plan-description truncation hard, so a regression-test failure would
+      // be visible as a "..." path coming through.
+      .config("spark.sql.maxPlanStringLength", "32")
+      .config("spark.sql.adaptive.enabled", "false")
+      .config("spark.ui.enabled", "false")
+      .config("spark.sql.queryExecutionListeners", classOf[DataflintIOListener].getName)
+      .getOrCreate()
+    spark.sparkContext.addSparkListener(collector)
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) {
+      spark.sparkContext.removeSparkListener(collector)
+      spark.stop()
+      SparkSession.clearActiveSession()
+      SparkSession.clearDefaultSession()
+    }
+  }
+
+  private def deleteRecursively(f: java.io.File): Unit = {
+    if (f.isDirectory) Option(f.listFiles).foreach(_.foreach(deleteRecursively))
+    f.delete()
+  }
+
+  /** Drain the async listener bus so any pending DataflintIOEvent is delivered to the collector. */
+  private def drain(): Unit = {
+    spark.sparkContext.listenerBus.waitUntilEmpty(10000)
+  }
+
+  private def find(op: String): Option[DataflintIOTarget] = {
+    val all = (0 until captured.size()).map(captured.get).flatMap(_.targets)
+    all.find(_.operation == op)
+  }
+
+  test("captures full write path through InsertIntoHadoopFsRelationCommand without truncation") {
+    val baseDir = java.nio.file.Files.createTempDirectory("dataflint-io").toFile
+    // Deliberately deep+long path so it would be truncated in plan descriptions.
+    val deepPath = new java.io.File(baseDir,
+      "a-very-long-and-deeply-nested-folder-name-aaaaaaaaaaaa/" +
+        "another-long-segment-bbbbbbbbbbbbbbbbbbbb/" +
+        "final-leaf-cccccccccccccccccccccccccccccc")
+    deepPath.getParentFile.mkdirs()
+    try {
+      captured.clear()
+
+      val session = spark
+      import session.implicits._
+      Seq(("a", 1), ("b", 2)).toDF("key", "value")
+        .write.mode("overwrite").parquet(deepPath.getAbsolutePath)
+
+      drain()
+      val writeTarget = find("write")
+        .getOrElse(fail(s"No write target captured. Got ${captured.size()} events."))
+
+      writeTarget.paths should have size 1
+      // Crucial assertion: full path, no truncation marker.
+      writeTarget.paths.head should include (deepPath.getName)
+      writeTarget.paths.head should not include "..."
+      writeTarget.saveMode shouldBe Some("Overwrite")
+      // nodeId is the pre-order index in qe.executedPlan; the write command exec
+      // is the root or near-root, so it should be a small non-negative integer.
+      writeTarget.nodeId should be >= 0
+
+      // Now read it back and verify the read target carries the full path too.
+      spark.read.parquet(deepPath.getAbsolutePath).collect()
+
+      drain()
+      val readTarget = find("read").getOrElse(fail("No read target captured."))
+      readTarget.paths.exists(p => p.contains(deepPath.getName) && !p.endsWith("...")) shouldBe true
+      readTarget.nodeId should be >= 0
+    } finally {
+      deleteRecursively(baseDir)
+    }
+  }
+}