Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ jobs:
os: [ubuntu-latest]
java: [temurin@17]
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- name: Checkout current branch (full)
uses: actions/checkout@v4
Expand Down Expand Up @@ -64,6 +67,64 @@ jobs:
working-directory: ./spark-ui

- name: Build and test plugin
run: sbt +test
id: sbt-test
run: |
set -o pipefail
sbt +test 2>&1 | tee sbt-test.log
working-directory: ./spark-plugin

- name: Surface failing test output
if: failure() && steps.sbt-test.conclusion == 'failure'
run: |
LOG=spark-plugin/sbt-test.log
{
echo "## sbt test failure"
echo ""
echo "### \`[error]\` / failed test lines"
echo '```'
# Grab failure markers plus 3 lines of trailing context.
grep -nE '\*\*\* (FAILED|ABORTED) \*\*\*|^\[error\]|^\[info\] [A-Z][^:]*:$' "$LOG" \
| tail -200 || true
echo '```'
echo ""
echo "### Last 200 log lines"
echo '```'
tail -200 "$LOG"
echo '```'
} >> "$GITHUB_STEP_SUMMARY"

- name: Upload sbt test log
if: always()
uses: actions/upload-artifact@v4
with:
name: sbt-test-log
path: spark-plugin/sbt-test.log
if-no-files-found: ignore
retention-days: 14

- name: Post failure block as PR comment
if: failure() && steps.sbt-test.conclusion == 'failure' && github.event_name == 'pull_request'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SHA: ${{ github.event.pull_request.head.sha }}
run: |
LOG=spark-plugin/sbt-test.log
{
echo "## sbt test failure — \`${SHA:0:7}\`"
echo ""
echo "### \`[error]\` / failed test lines"
echo '```'
grep -nE '\*\*\* (FAILED|ABORTED) \*\*\*|^\[error\]|^\[info\] [A-Z][^:]*:$' "$LOG" \
| tail -200 || true
echo '```'
echo ""
echo "<details><summary>Last 300 lines of sbt output</summary>"
echo ""
echo '```'
tail -300 "$LOG"
echo '```'
echo "</details>"
} > /tmp/ci-failure-comment.md
gh pr comment "$PR_NUMBER" --body-file /tmp/ci-failure-comment.md

3 changes: 2 additions & 1 deletion spark-plugin/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ lazy val pluginspark4 = (project in file("pluginspark4"))
Test / unmanagedSources ++= {
val pluginspark3Tests = (pluginspark3 / Test / sourceDirectory).value / "scala"
Seq(
pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "DataFlintCodegenFallbackSpec.scala"
pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "DataFlintCodegenFallbackSpec.scala",
pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "ZDataflintIOListenerSpec.scala"
)
},

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class DataflintSparkUICommonInstaller extends Logging {
val deltaLakeCacheZindexFieldsToProperties = context.conf.getBoolean("spark.dataflint.instrument.deltalake.cacheZindexFieldsToProperties", defaultValue = true)
val deltaLakeHistoryLimit = context.conf.getInt("spark.dataflint.instrument.deltalake.historyLimit", defaultValue = 1000)
val icebergAuthCatalogDiscovery = context.conf.getBoolean("spark.dataflint.iceberg.autoCatalogDiscovery", defaultValue = false)
val ioTrackingEnabled = context.conf.getBoolean(DataflintSparkUICommonLoader.IO_TRACKING_ENABLED, defaultValue = false)
if (ioTrackingEnabled) {
DataflintSparkUICommonLoader.registerIOListener(context)
}
if(icebergInstalled && icebergEnabled) {
if(icebergAuthCatalogDiscovery && isMetricLoaderInRightClassLoader()) {
context.conf.getAll.filter(_._1.startsWith("spark.sql.catalog")).filter(keyValue => keyValue._2 == "org.apache.iceberg.spark.SparkCatalog" || keyValue._2 == "org.apache.iceberg.spark.SparkSessionCatalog").foreach(keyValue => {
Expand Down Expand Up @@ -135,6 +139,8 @@ class DataflintSparkUICommonInstaller extends Logging {
object DataflintSparkUICommonLoader extends Logging {

private val DATAFLINT_EXTENSION_CLASS = "org.apache.spark.dataflint.DataFlintInstrumentationExtension"
private val DATAFLINT_IO_LISTENER_CLASS = "org.apache.spark.dataflint.listener.DataflintIOListener"
val IO_TRACKING_ENABLED = "spark.dataflint.io.tracking.enabled"
val INSTRUMENT_SPARK_ENABLED = "spark.dataflint.instrument.spark.enabled"
val INSTRUMENT_MAP_IN_PANDAS_ENABLED = "spark.dataflint.instrument.spark.mapInPandas.enabled"
val INSTRUMENT_MAP_IN_ARROW_ENABLED = "spark.dataflint.instrument.spark.mapInArrow.enabled"
Expand Down Expand Up @@ -205,4 +211,31 @@ object DataflintSparkUICommonLoader extends Logging {
logWarning("Could not register DataFlint instrumentation extension", e)
}
}

/**
* Append [[org.apache.spark.dataflint.listener.DataflintIOListener]] to
* `spark.sql.queryExecutionListeners` so Spark instantiates it when the
* SparkSession is built. The listener captures full read/write/save paths
* and table identifiers from the typed analyzed plan (bypassing
* SQLConf.maxToStringFields truncation of plan descriptions).
*
* This is in the org.apache.spark.dataflint package to access SparkContext.conf
* (which is private[spark]).
*/
def registerIOListener(sc: SparkContext): Unit = {
try {
val current = sc.conf.get("spark.sql.queryExecutionListeners", "")
if (current.contains(DATAFLINT_IO_LISTENER_CLASS)) {
logInfo("DataflintIOListener already registered in spark.sql.queryExecutionListeners")
} else {
val updated = if (current.isEmpty) DATAFLINT_IO_LISTENER_CLASS
else s"$current,$DATAFLINT_IO_LISTENER_CLASS"
sc.conf.set("spark.sql.queryExecutionListeners", updated)
logInfo("Registered DataflintIOListener in spark.sql.queryExecutionListeners")
}
} catch {
case e: Throwable =>
logWarning("Could not register DataflintIOListener", e)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package org.apache.spark.dataflint.listener

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommandExec}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}

import scala.util.Try

/**
* Pulls IO targets (reads/writes/saves) out of a QueryExecution by reading
* typed plan fields — not regex over `simpleString` — so paths and identifiers
* are captured at full length regardless of SQLConf.maxToStringFields truncation.
*
* Primary walk is over `qe.executedPlan` with a pre-order index assigned to each
* node; that index matches the nodeId Spark's [[org.apache.spark.sql.execution.ui.SparkPlanGraph]]
* assigns, so consumers can correlate each target back to the SparkUI SQL tab.
* For [[DataWritingCommandExec]] we delegate to the logical-plan matcher on the
* wrapped `cmd` and overlay the physical nodeId.
*/
object DataflintIOExtractor extends Logging {

/**
* Walk the physical plan, attaching the pre-order nodeId to every captured target.
* A single node can yield multiple targets (e.g. a FileSourceScanExec with multiple
* rootPaths emits one target per path) — keeps each target small enough that one
* Spark event per target stays well under Databricks' 2 MB per-event cap.
*/
def extract(qe: QueryExecution): Seq[DataflintIOTarget] = {
val out = collection.mutable.ArrayBuffer.empty[DataflintIOTarget]
Try {
var idx = 0
qe.executedPlan.foreach { node =>
Try(extractFromPhysical(node, idx)).toOption.foreach(out ++= _)
idx += 1
}
}.recover {
case t: Throwable => logWarning("DataflintIOExtractor: failed to walk plan", t)
}
out.toSeq
}

private def extractFromPhysical(node: SparkPlan, nodeId: Int): Seq[DataflintIOTarget] = node match {

// --- PHYSICAL READS ------------------------------------------------------
case fsse: FileSourceScanExec =>
val rootPaths = fsse.relation.location.rootPaths.map(_.toString)
val base = DataflintIOTarget(
nodeId = nodeId,
operation = "read",
format = fsse.relation.fileFormat.getClass.getSimpleName,
paths = Seq.empty,
tableName = fsse.tableIdentifier.map(_.unquotedString),
saveMode = None,
partitionColumns = fsse.relation.partitionSchema.fieldNames.toSeq,
options = fsse.relation.options
)
if (rootPaths.isEmpty) Seq(base)
else rootPaths.map(p => base.copy(paths = Seq(p)))

case bse: BatchScanExec =>
Seq(DataflintIOTarget(
nodeId = nodeId,
operation = "read",
format = Option(bse.table).map(_.getClass.getSimpleName).getOrElse("DataSourceV2"),
paths = Seq.empty,
tableName = Option(bse.table).map(_.name),
saveMode = None,
partitionColumns = Seq.empty,
options = Map.empty
))

case rdse: RowDataSourceScanExec =>
Seq(DataflintIOTarget(
nodeId = nodeId,
operation = "read",
format = Try(rdse.relation.getClass.getSimpleName).getOrElse("RowDataSource"),
paths = Seq.empty,
tableName = Try(rdse.tableIdentifier.map(_.unquotedString)).getOrElse(None),
saveMode = None,
partitionColumns = Seq.empty,
options = Map.empty
))

// --- PHYSICAL WRITES (V1) — delegate to logical matcher, overlay nodeId --
case dwce: DataWritingCommandExec =>
extractFromLogical(dwce.cmd).map(_.copy(nodeId = nodeId)).toSeq

case _ => Seq.empty
}

/**
* Logical-plan matcher for the commands [[DataWritingCommandExec]] wraps.
* Returns nodeId = -1; the physical-plan caller overlays the real id.
*/
private def extractFromLogical(node: LogicalPlan): Option[DataflintIOTarget] = node match {

case lr: LogicalRelation =>
val (paths, format) = lr.relation match {
case h: HadoopFsRelation =>
(h.location.rootPaths.map(_.toString), h.fileFormat.getClass.getSimpleName)
case other =>
(Seq.empty[String], other.getClass.getSimpleName)
}
Some(DataflintIOTarget(
nodeId = -1,
operation = "read",
format = format,
paths = paths,
tableName = lr.catalogTable.map(_.identifier.unquotedString),
saveMode = None,
partitionColumns = lr.catalogTable.map(_.partitionColumnNames).getOrElse(Seq.empty),
options = Map.empty
))

case d: DataSourceV2Relation =>
val opts = safeOptionsMap(d.options.asCaseSensitiveMap())
Some(DataflintIOTarget(
nodeId = -1,
operation = "read",
format = Option(d.table).map(_.getClass.getSimpleName).getOrElse("DataSourceV2"),
paths = optionalPath(opts),
tableName = d.identifier.map(_.toString).orElse(Option(d.table).map(_.name)),
saveMode = None,
partitionColumns = Seq.empty,
options = opts
))

case c: InsertIntoHadoopFsRelationCommand =>
Some(DataflintIOTarget(
nodeId = -1,
operation = "write",
format = c.fileFormat.getClass.getSimpleName,
paths = Seq(c.outputPath.toString),
tableName = c.catalogTable.map(_.identifier.unquotedString),
saveMode = Some(c.mode.toString),
partitionColumns = c.partitionColumns.map(_.name),
options = c.options
))

case c: SaveIntoDataSourceCommand =>
Some(DataflintIOTarget(
nodeId = -1,
operation = "write",
format = c.dataSource.getClass.getSimpleName,
paths = c.options.get("path").toSeq,
tableName = c.options.get("dbtable").orElse(c.options.get("table")),
saveMode = Some(c.mode.toString),
partitionColumns = Seq.empty,
options = c.options
))

case c: CreateDataSourceTableAsSelectCommand =>
Some(DataflintIOTarget(
nodeId = -1,
operation = "write",
format = c.table.provider.getOrElse(""),
paths = c.table.storage.locationUri.map(_.toString).toSeq,
tableName = Some(c.table.identifier.unquotedString),
saveMode = Some(c.mode.toString),
partitionColumns = c.table.partitionColumnNames,
options = c.table.storage.properties
))

case w: V2WriteCommand =>
val named = Try(w.table).toOption
Some(DataflintIOTarget(
nodeId = -1,
operation = "write",
format = named.map(_.getClass.getSimpleName).getOrElse("V2Write"),
paths = Seq.empty,
tableName = named.flatMap(t => Option(t.name)),
saveMode = Some(w.getClass.getSimpleName),
partitionColumns = Seq.empty,
options = Map.empty
))

case _ => None
}

/** Defensive conversion of a Spark CaseInsensitiveStringMap to a Scala Map. */
private def safeOptionsMap(m: java.util.Map[String, String]): Map[String, String] = {
Try {
val it = m.entrySet().iterator()
val b = Map.newBuilder[String, String]
while (it.hasNext) { val e = it.next(); b += (e.getKey -> e.getValue) }
b.result()
}.getOrElse(Map.empty)
}

private def optionalPath(opts: Map[String, String]): Seq[String] =
opts.get("path").orElse(opts.get("paths")).toSeq
}
Loading
Loading