[-o output.py] [-d DIALECT]");
+ System.err.println();
+ System.err.println("Options:");
+ System.err.println(" -o FILE Write generated PySpark code to FILE (default: stdout)");
+ System.err.println(" -d DIALECT Preferred SQL dialect: ANSI_SQL, SNOWFLAKE, DATABRICKS (default: ANSI_SQL)");
+ System.exit(1);
+ }
+
+ String inputFile = args[0];
+ String outputFile = null;
+ String dialect = "ANSI_SQL";
+
+ for (int i = 1; i < args.length; i++) {
+ switch (args[i]) {
+ case "-o":
+ if (i + 1 < args.length) {
+ outputFile = args[++i];
+ }
+ break;
+ case "-d":
+ if (i + 1 < args.length) {
+ dialect = args[++i];
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+ // Parse the OSI model
+ OsiModelParser parser = new OsiModelParser();
+ OsiModel model = parser.parse(Paths.get(inputFile));
+
+ if (model.getSemanticModels().isEmpty()) {
+ System.err.println("Error: no semantic_model found in " + inputFile);
+ System.exit(1);
+ }
+
+ // Generate PySpark code
+ SparkCodeGenerator generator = new SparkCodeGenerator(dialect);
+ String code = generator.generate(model);
+
+ if (outputFile != null) {
+ Files.write(Paths.get(outputFile), code.getBytes(StandardCharsets.UTF_8));
+ System.out.println("Generated PySpark code written to " + outputFile);
+ } else {
+ System.out.println(code);
+ }
+ }
+}
diff --git a/converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java b/converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java
new file mode 100644
index 0000000..6725509
--- /dev/null
+++ b/converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java
@@ -0,0 +1,246 @@
+package org.osi.converter.spark;
+
+import org.osi.converter.spark.model.OsiModel;
+import org.osi.converter.spark.model.OsiModel.*;
+
+import java.util.List;
+
+/**
+ * Generates PySpark code from a parsed {@link OsiModel}.
+ *
+ * The generated code creates DataFrames, registers temp views,
+ * builds join helpers from relationships, and exposes metric functions.
+ */
+public class SparkCodeGenerator {
+
+ private final String dialect;
+
+ public SparkCodeGenerator() {
+ this("ANSI_SQL");
+ }
+
+ public SparkCodeGenerator(String dialect) {
+ this.dialect = dialect;
+ }
+
+ /**
+ * Generate the full PySpark module for the given OSI model.
+ */
+ public String generate(OsiModel model) {
+ StringBuilder sb = new StringBuilder();
+
+ for (SemanticModel sm : model.getSemanticModels()) {
+ generateHeader(sb, sm);
+ generateDatasetLoaders(sb, sm);
+ generateLoadAll(sb, sm);
+ generateJoinHelpers(sb, sm);
+ generateMetrics(sb, sm);
+ generateMain(sb, sm);
+ }
+
+ return sb.toString();
+ }
+
+ // -----------------------------------------------------------------------
+ // Header
+ // -----------------------------------------------------------------------
+
+ private void generateHeader(StringBuilder sb, SemanticModel sm) {
+ sb.append("\"\"\"\n");
+ sb.append("Auto-generated PySpark code from OSI semantic model: ").append(sm.getName()).append("\n");
+ if (sm.getDescription() != null) {
+ sb.append("\n").append(sm.getDescription()).append("\n");
+ }
+ sb.append("\"\"\"\n\n");
+ sb.append("from pyspark.sql import SparkSession, DataFrame\n");
+ sb.append("from pyspark.sql import functions as F\n\n\n");
+
+ sb.append("def get_spark() -> SparkSession:\n");
+ sb.append(" \"\"\"Return or create a SparkSession.\"\"\"\n");
+ sb.append(" return (\n");
+ sb.append(" SparkSession.builder\n");
+ sb.append(" .appName(\"").append(sm.getName()).append("\")\n");
+ sb.append(" .getOrCreate()\n");
+ sb.append(" )\n\n");
+ }
+
+ // -----------------------------------------------------------------------
+ // Dataset loaders
+ // -----------------------------------------------------------------------
+
+ private void generateDatasetLoaders(StringBuilder sb, SemanticModel sm) {
+ for (Dataset ds : sm.getDatasets()) {
+ generateDatasetLoader(sb, ds);
+ }
+ }
+
+ private void generateDatasetLoader(StringBuilder sb, Dataset ds) {
+ String funcName = "load_" + ds.getName();
+ sb.append("\ndef ").append(funcName).append("(spark: SparkSession) -> DataFrame:\n");
+ sb.append(" \"\"\"\n");
+ sb.append(" Load dataset: ").append(ds.getName()).append("\n");
+ if (ds.getDescription() != null) {
+ sb.append(" ").append(ds.getDescription()).append("\n");
+ }
+ sb.append(" Source: ").append(ds.getSource()).append("\n");
+ sb.append(" \"\"\"\n");
+
+ sb.append(" df = spark.table(\"").append(ds.getSource()).append("\")\n");
+
+ // Add computed columns for fields whose expression differs from their name
+ for (Field field : ds.getFields()) {
+ String expr = pickExpression(field.getExpressions());
+ if (expr != null && !expr.equals(field.getName())) {
+ sb.append(" df = df.withColumn(\"").append(field.getName())
+ .append("\", F.expr(\"").append(escapeString(expr)).append("\"))\n");
+ }
+ }
+
+ sb.append(" df.createOrReplaceTempView(\"").append(ds.getName()).append("\")\n");
+ sb.append(" return df\n\n");
+ }
+
+ // -----------------------------------------------------------------------
+ // Load all
+ // -----------------------------------------------------------------------
+
+ private void generateLoadAll(StringBuilder sb, SemanticModel sm) {
+ sb.append("\ndef load_all_datasets(spark: SparkSession) -> dict[str, DataFrame]:\n");
+ sb.append(" \"\"\"Load all datasets and register temp views. Returns a dict of name -> DataFrame.\"\"\"\n");
+ sb.append(" datasets = {}\n");
+ for (Dataset ds : sm.getDatasets()) {
+ sb.append(" datasets[\"").append(ds.getName()).append("\"] = load_")
+ .append(ds.getName()).append("(spark)\n");
+ }
+ sb.append(" return datasets\n\n");
+ }
+
+ // -----------------------------------------------------------------------
+ // Join helpers
+ // -----------------------------------------------------------------------
+
+ private void generateJoinHelpers(StringBuilder sb, SemanticModel sm) {
+ for (Relationship rel : sm.getRelationships()) {
+ generateJoinHelper(sb, rel);
+ }
+ }
+
+ private void generateJoinHelper(StringBuilder sb, Relationship rel) {
+ String fromDs = rel.getFrom();
+ String toDs = rel.getTo();
+ String funcName = "join_" + rel.getName();
+
+ sb.append("\ndef ").append(funcName).append("(")
+ .append(fromDs).append("_df: DataFrame, ")
+ .append(toDs).append("_df: DataFrame) -> DataFrame:\n");
+ sb.append(" \"\"\"\n");
+ sb.append(" Join ").append(fromDs).append(" -> ").append(toDs).append("\n");
+ sb.append(" \"\"\"\n");
+
+ List fromCols = rel.getFromColumns();
+ List toCols = rel.getToColumns();
+
+ if (fromCols.size() == 1) {
+ sb.append(" return ").append(fromDs).append("_df.join(")
+ .append(toDs).append("_df, ")
+ .append(fromDs).append("_df[\"").append(fromCols.get(0)).append("\"] == ")
+ .append(toDs).append("_df[\"").append(toCols.get(0)).append("\"], \"inner\")\n\n");
+ } else {
+ String condition = buildCompositeCondition(fromDs, toDs, fromCols, toCols);
+ sb.append(" condition = ").append(condition).append("\n");
+ sb.append(" return ").append(fromDs).append("_df.join(")
+ .append(toDs).append("_df, condition, \"inner\")\n\n");
+ }
+ }
+
+ private String buildCompositeCondition(String fromDs, String toDs,
+ List fromCols, List toCols) {
+ StringBuilder condition = new StringBuilder();
+ for (int i = 0; i < fromCols.size(); i++) {
+ if (i > 0) {
+ condition.append(" & ");
+ }
+ condition.append(fromDs).append("_df[\"").append(fromCols.get(i)).append("\"] == ")
+ .append(toDs).append("_df[\"").append(toCols.get(i)).append("\"]");
+ }
+ return condition.toString();
+ }
+
+ // -----------------------------------------------------------------------
+ // Metrics
+ // -----------------------------------------------------------------------
+
+ private void generateMetrics(StringBuilder sb, SemanticModel sm) {
+ for (Metric metric : sm.getMetrics()) {
+ generateMetric(sb, metric);
+ }
+ }
+
+ private void generateMetric(StringBuilder sb, Metric metric) {
+ String expr = pickExpression(metric.getExpressions());
+ if (expr == null) {
+ return;
+ }
+
+ String funcName = "compute_" + metric.getName();
+ sb.append("\ndef ").append(funcName).append("(spark: SparkSession) -> DataFrame:\n");
+ sb.append(" \"\"\"\n");
+ sb.append(" Metric: ").append(metric.getName()).append("\n");
+ if (metric.getDescription() != null) {
+ sb.append(" ").append(metric.getDescription()).append("\n");
+ }
+ sb.append(" Expression: ").append(expr).append("\n");
+ sb.append(" \"\"\"\n");
+ sb.append(" return spark.sql(\"SELECT ").append(escapeString(expr))
+ .append(" AS ").append(metric.getName()).append("\")\n\n");
+ }
+
+ // -----------------------------------------------------------------------
+ // Main block
+ // -----------------------------------------------------------------------
+
+ private void generateMain(StringBuilder sb, SemanticModel sm) {
+ sb.append("\nif __name__ == \"__main__\":\n");
+ sb.append(" spark = get_spark()\n");
+ sb.append(" print(f\"Spark session started: {spark.sparkContext.appName}\")\n\n");
+
+ sb.append(" # Load all datasets\n");
+ sb.append(" dfs = load_all_datasets(spark)\n");
+ sb.append(" for name, df in dfs.items():\n");
+ sb.append(" print(f\"Dataset {name}: {df.count()} rows\")\n\n");
+
+ if (!sm.getRelationships().isEmpty()) {
+ Relationship rel = sm.getRelationships().get(0);
+ sb.append(" # Example join: ").append(rel.getName()).append("\n");
+ sb.append(" joined = join_").append(rel.getName())
+ .append("(dfs[\"").append(rel.getFrom())
+ .append("\"], dfs[\"").append(rel.getTo()).append("\"])\n");
+ sb.append(" joined.show(5)\n\n");
+ }
+
+ sb.append(" spark.stop()\n");
+ }
+
+ // -----------------------------------------------------------------------
+ // Utilities
+ // -----------------------------------------------------------------------
+
+ /**
+ * Pick the expression for the preferred dialect, falling back to the first available.
+ */
+ private String pickExpression(List expressions) {
+ if (expressions == null || expressions.isEmpty()) {
+ return null;
+ }
+ for (DialectExpression de : expressions) {
+ if (dialect.equals(de.getDialect())) {
+ return de.getExpression();
+ }
+ }
+ return expressions.get(0).getExpression();
+ }
+
+ private String escapeString(String s) {
+ return s.replace("\\", "\\\\").replace("\"", "\\\"");
+ }
+}
diff --git a/converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java b/converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java
new file mode 100644
index 0000000..a205663
--- /dev/null
+++ b/converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java
@@ -0,0 +1,155 @@
+package org.osi.converter.spark.model;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Java representation of an OSI semantic model parsed from YAML.
+ */
+public class OsiModel {
+
+ private String version;
+ private List semanticModels = new ArrayList<>();
+
+ public String getVersion() {
+ return version;
+ }
+
+ public void setVersion(String version) {
+ this.version = version;
+ }
+
+ public List getSemanticModels() {
+ return semanticModels;
+ }
+
+ public void setSemanticModels(List semanticModels) {
+ this.semanticModels = semanticModels;
+ }
+
+ // -----------------------------------------------------------------------
+ // Nested model classes
+ // -----------------------------------------------------------------------
+
+ public static class SemanticModel {
+ private String name;
+ private String description;
+ private List datasets = new ArrayList<>();
+ private List relationships = new ArrayList<>();
+ private List metrics = new ArrayList<>();
+
+ public String getName() { return name; }
+ public void setName(String name) { this.name = name; }
+
+ public String getDescription() { return description; }
+ public void setDescription(String description) { this.description = description; }
+
+ public List getDatasets() { return datasets; }
+ public void setDatasets(List datasets) { this.datasets = datasets; }
+
+ public List getRelationships() { return relationships; }
+ public void setRelationships(List relationships) { this.relationships = relationships; }
+
+ public List getMetrics() { return metrics; }
+ public void setMetrics(List metrics) { this.metrics = metrics; }
+ }
+
+ public static class Dataset {
+ private String name;
+ private String source;
+ private List primaryKey = new ArrayList<>();
+ private String description;
+ private List fields = new ArrayList<>();
+
+ public String getName() { return name; }
+ public void setName(String name) { this.name = name; }
+
+ public String getSource() { return source; }
+ public void setSource(String source) { this.source = source; }
+
+ public List getPrimaryKey() { return primaryKey; }
+ public void setPrimaryKey(List primaryKey) { this.primaryKey = primaryKey; }
+
+ public String getDescription() { return description; }
+ public void setDescription(String description) { this.description = description; }
+
+ public List getFields() { return fields; }
+ public void setFields(List fields) { this.fields = fields; }
+ }
+
+ public static class Field {
+ private String name;
+ private String description;
+ private List expressions = new ArrayList<>();
+ private boolean isTime;
+
+ public String getName() { return name; }
+ public void setName(String name) { this.name = name; }
+
+ public String getDescription() { return description; }
+ public void setDescription(String description) { this.description = description; }
+
+ public List getExpressions() { return expressions; }
+ public void setExpressions(List expressions) { this.expressions = expressions; }
+
+ public boolean isTime() { return isTime; }
+ public void setTime(boolean time) { isTime = time; }
+ }
+
+ public static class DialectExpression {
+ private String dialect;
+ private String expression;
+
+ public DialectExpression() {}
+
+ public DialectExpression(String dialect, String expression) {
+ this.dialect = dialect;
+ this.expression = expression;
+ }
+
+ public String getDialect() { return dialect; }
+ public void setDialect(String dialect) { this.dialect = dialect; }
+
+ public String getExpression() { return expression; }
+ public void setExpression(String expression) { this.expression = expression; }
+ }
+
+ public static class Relationship {
+ private String name;
+ private String from;
+ private String to;
+ private List fromColumns = new ArrayList<>();
+ private List toColumns = new ArrayList<>();
+
+ public String getName() { return name; }
+ public void setName(String name) { this.name = name; }
+
+ public String getFrom() { return from; }
+ public void setFrom(String from) { this.from = from; }
+
+ public String getTo() { return to; }
+ public void setTo(String to) { this.to = to; }
+
+ public List getFromColumns() { return fromColumns; }
+ public void setFromColumns(List fromColumns) { this.fromColumns = fromColumns; }
+
+ public List getToColumns() { return toColumns; }
+ public void setToColumns(List toColumns) { this.toColumns = toColumns; }
+ }
+
+ public static class Metric {
+ private String name;
+ private String description;
+ private List expressions = new ArrayList<>();
+
+ public String getName() { return name; }
+ public void setName(String name) { this.name = name; }
+
+ public String getDescription() { return description; }
+ public void setDescription(String description) { this.description = description; }
+
+ public List getExpressions() { return expressions; }
+ public void setExpressions(List expressions) { this.expressions = expressions; }
+ }
+}
diff --git a/converters/spark/src/test/java/org/osi/converter/spark/OsiSparkConverterTest.java b/converters/spark/src/test/java/org/osi/converter/spark/OsiSparkConverterTest.java
new file mode 100644
index 0000000..ecd7898
--- /dev/null
+++ b/converters/spark/src/test/java/org/osi/converter/spark/OsiSparkConverterTest.java
@@ -0,0 +1,152 @@
+package org.osi.converter.spark;
+
+import org.junit.jupiter.api.Test;
+import org.osi.converter.spark.model.OsiModel;
+
+import java.io.ByteArrayInputStream;
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class OsiSparkConverterTest {
+
+ private static final String MINIMAL_MODEL =
+ "version: \"0.1.1\"\n"
+ + "\n"
+ + "semantic_model:\n"
+ + " - name: test_model\n"
+ + " description: A test model\n"
+ + " datasets:\n"
+ + " - name: orders\n"
+ + " source: db.public.orders\n"
+ + " primary_key: [order_id]\n"
+ + " description: Order fact table\n"
+ + " fields:\n"
+ + " - name: order_id\n"
+ + " expression:\n"
+ + " dialects:\n"
+ + " - dialect: ANSI_SQL\n"
+ + " expression: order_id\n"
+ + " - name: total_amount\n"
+ + " expression:\n"
+ + " dialects:\n"
+ + " - dialect: ANSI_SQL\n"
+ + " expression: quantity * unit_price\n"
+ + " description: Computed total\n"
+ + " - name: customer\n"
+ + " source: db.public.customer\n"
+ + " primary_key: [customer_id]\n"
+ + " fields:\n"
+ + " - name: customer_id\n"
+ + " expression:\n"
+ + " dialects:\n"
+ + " - dialect: ANSI_SQL\n"
+ + " expression: customer_id\n"
+ + " - name: full_name\n"
+ + " expression:\n"
+ + " dialects:\n"
+ + " - dialect: ANSI_SQL\n"
+ + " expression: \"first_name || ' ' || last_name\"\n"
+ + " relationships:\n"
+ + " - name: orders_to_customer\n"
+ + " from: orders\n"
+ + " to: customer\n"
+ + " from_columns: [customer_id]\n"
+ + " to_columns: [customer_id]\n"
+ + " metrics:\n"
+ + " - name: total_revenue\n"
+ + " expression:\n"
+ + " dialects:\n"
+ + " - dialect: ANSI_SQL\n"
+ + " expression: SUM(orders.total_amount)\n"
+ + " description: Total revenue across all orders\n";
+
+ @Test
+ void testParseMinimalModel() {
+ OsiModelParser parser = new OsiModelParser();
+ OsiModel model = parser.parse(
+ new ByteArrayInputStream(MINIMAL_MODEL.getBytes(StandardCharsets.UTF_8)));
+
+ assertEquals("0.1.1", model.getVersion());
+ assertEquals(1, model.getSemanticModels().size());
+
+ OsiModel.SemanticModel sm = model.getSemanticModels().get(0);
+ assertEquals("test_model", sm.getName());
+ assertEquals(2, sm.getDatasets().size());
+ assertEquals(1, sm.getRelationships().size());
+ assertEquals(1, sm.getMetrics().size());
+ }
+
+ @Test
+ void testParseDatasetFields() {
+ OsiModelParser parser = new OsiModelParser();
+ OsiModel model = parser.parse(
+ new ByteArrayInputStream(MINIMAL_MODEL.getBytes(StandardCharsets.UTF_8)));
+
+ OsiModel.Dataset orders = model.getSemanticModels().get(0).getDatasets().get(0);
+ assertEquals("orders", orders.getName());
+ assertEquals("db.public.orders", orders.getSource());
+ assertEquals(2, orders.getFields().size());
+
+ OsiModel.Field computed = orders.getFields().get(1);
+ assertEquals("total_amount", computed.getName());
+ assertEquals("quantity * unit_price", computed.getExpressions().get(0).getExpression());
+ }
+
+ @Test
+ void testGenerateContainsExpectedFunctions() {
+ OsiModelParser parser = new OsiModelParser();
+ OsiModel model = parser.parse(
+ new ByteArrayInputStream(MINIMAL_MODEL.getBytes(StandardCharsets.UTF_8)));
+
+ SparkCodeGenerator generator = new SparkCodeGenerator("ANSI_SQL");
+ String code = generator.generate(model);
+
+ // Dataset loaders
+ assertTrue(code.contains("def load_orders(spark: SparkSession)"));
+ assertTrue(code.contains("def load_customer(spark: SparkSession)"));
+ assertTrue(code.contains("def load_all_datasets(spark: SparkSession)"));
+
+ // Computed columns
+ assertTrue(code.contains("df.withColumn(\"total_amount\", F.expr(\"quantity * unit_price\"))"));
+ assertTrue(code.contains("df.withColumn(\"full_name\", F.expr(\"first_name || ' ' || last_name\"))"));
+
+ // Temp views
+ assertTrue(code.contains("df.createOrReplaceTempView(\"orders\")"));
+
+ // Join helper
+ assertTrue(code.contains("def join_orders_to_customer("));
+
+ // Metric
+ assertTrue(code.contains("def compute_total_revenue(spark: SparkSession)"));
+ assertTrue(code.contains("SUM(orders.total_amount) AS total_revenue"));
+ }
+
+ @Test
+ void testGenerateWithDatabricksDialect() {
+ String multiDialect =
+ "version: \"0.1.1\"\n"
+ + "semantic_model:\n"
+ + " - name: multi\n"
+ + " datasets:\n"
+ + " - name: sales\n"
+ + " source: catalog.schema.sales\n"
+ + " fields:\n"
+ + " - name: amount\n"
+ + " expression:\n"
+ + " dialects:\n"
+ + " - dialect: ANSI_SQL\n"
+ + " expression: amount\n"
+ + " - dialect: DATABRICKS\n"
+ + " expression: \"CAST(amount AS DECIMAL(18,2))\"\n";
+
+ OsiModelParser parser = new OsiModelParser();
+ OsiModel model = parser.parse(
+ new ByteArrayInputStream(multiDialect.getBytes(StandardCharsets.UTF_8)));
+
+ SparkCodeGenerator generator = new SparkCodeGenerator("DATABRICKS");
+ String code = generator.generate(model);
+
+ assertTrue(code.contains("CAST(amount AS DECIMAL(18,2))"));
+ }
+}