From bdfc27e8ba451301e748c65a752def8b96ca75d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?JB=20Onofr=C3=A9?= Date: Mon, 23 Mar 2026 13:31:13 +0100 Subject: [PATCH] Add Apache Spark converter module Introduces a Java/Maven converter under converters/spark/ that parses OSI semantic model YAML files and generates PySpark code including dataset loaders, join helpers, and metric functions. --- converters/spark/pom.xml | 65 +++++ .../osi/converter/spark/OsiModelParser.java | 180 +++++++++++++ .../converter/spark/OsiSparkConverter.java | 71 +++++ .../converter/spark/SparkCodeGenerator.java | 246 ++++++++++++++++++ .../osi/converter/spark/model/OsiModel.java | 155 +++++++++++ .../spark/OsiSparkConverterTest.java | 152 +++++++++++ 6 files changed, 869 insertions(+) create mode 100644 converters/spark/pom.xml create mode 100644 converters/spark/src/main/java/org/osi/converter/spark/OsiModelParser.java create mode 100644 converters/spark/src/main/java/org/osi/converter/spark/OsiSparkConverter.java create mode 100644 converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java create mode 100644 converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java create mode 100644 converters/spark/src/test/java/org/osi/converter/spark/OsiSparkConverterTest.java diff --git a/converters/spark/pom.xml b/converters/spark/pom.xml new file mode 100644 index 0000000..d95d3fc --- /dev/null +++ b/converters/spark/pom.xml @@ -0,0 +1,65 @@ + + + 4.0.0 + + org.osi + osi-spark-converter + 0.1.0-SNAPSHOT + jar + + OSI Spark Converter + Converts OSI semantic models to Apache Spark code + + + 11 + 11 + UTF-8 + 3.5.1 + 2.2 + 5.10.2 + + + + + + org.yaml + snakeyaml + ${snakeyaml.version} + + + + + org.apache.spark + spark-sql_2.13 + ${spark.version} + provided + + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.3.0 + + + + org.osi.converter.spark.OsiSparkConverter + + + + + + + diff --git 
a/converters/spark/src/main/java/org/osi/converter/spark/OsiModelParser.java b/converters/spark/src/main/java/org/osi/converter/spark/OsiModelParser.java new file mode 100644 index 0000000..bc214c5 --- /dev/null +++ b/converters/spark/src/main/java/org/osi/converter/spark/OsiModelParser.java @@ -0,0 +1,180 @@ +package org.osi.converter.spark; + +import org.osi.converter.spark.model.OsiModel; +import org.osi.converter.spark.model.OsiModel.*; +import org.yaml.snakeyaml.Yaml; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + + +/** + * Parses an OSI YAML file into an {@link OsiModel}. + */ +public class OsiModelParser { + + /** + * Parse an OSI YAML file from the given path. + */ + public OsiModel parse(Path yamlPath) throws IOException { + try (InputStream is = Files.newInputStream(yamlPath)) { + return parse(is); + } + } + + /** + * Parse an OSI YAML file from an input stream. 
+ */ + @SuppressWarnings("unchecked") + public OsiModel parse(InputStream is) { + Yaml yaml = new Yaml(); + Map root = yaml.load(is); + + OsiModel model = new OsiModel(); + model.setVersion((String) root.get("version")); + + List> smList = (List>) root.get("semantic_model"); + if (smList == null) { + return model; + } + + List semanticModels = new ArrayList<>(); + for (Map smMap : smList) { + semanticModels.add(parseSemanticModel(smMap)); + } + model.setSemanticModels(semanticModels); + return model; + } + + @SuppressWarnings("unchecked") + private SemanticModel parseSemanticModel(Map map) { + SemanticModel sm = new SemanticModel(); + sm.setName((String) map.get("name")); + sm.setDescription((String) map.get("description")); + + // Datasets + List> dsList = (List>) map.get("datasets"); + if (dsList != null) { + List datasets = new ArrayList<>(); + for (Map dsMap : dsList) { + datasets.add(parseDataset(dsMap)); + } + sm.setDatasets(datasets); + } + + // Relationships + List> relList = (List>) map.get("relationships"); + if (relList != null) { + List relationships = new ArrayList<>(); + for (Map relMap : relList) { + relationships.add(parseRelationship(relMap)); + } + sm.setRelationships(relationships); + } + + // Metrics + List> metricList = (List>) map.get("metrics"); + if (metricList != null) { + List metrics = new ArrayList<>(); + for (Map mMap : metricList) { + metrics.add(parseMetric(mMap)); + } + sm.setMetrics(metrics); + } + + return sm; + } + + @SuppressWarnings("unchecked") + private Dataset parseDataset(Map map) { + Dataset ds = new Dataset(); + ds.setName((String) map.get("name")); + ds.setSource((String) map.get("source")); + ds.setDescription((String) map.get("description")); + + List pk = (List) map.get("primary_key"); + if (pk != null) { + ds.setPrimaryKey(new ArrayList<>(pk)); + } + + List> fieldList = (List>) map.get("fields"); + if (fieldList != null) { + List fields = new ArrayList<>(); + for (Map fMap : fieldList) { + 
fields.add(parseField(fMap)); + } + ds.setFields(fields); + } + return ds; + } + + @SuppressWarnings("unchecked") + private Field parseField(Map map) { + Field field = new Field(); + field.setName((String) map.get("name")); + field.setDescription((String) map.get("description")); + + // Dimension + Map dim = (Map) map.get("dimension"); + if (dim != null) { + Object isTime = dim.get("is_time"); + field.setTime(Boolean.TRUE.equals(isTime)); + } + + // Expressions + field.setExpressions(parseDialectExpressions(map)); + return field; + } + + @SuppressWarnings("unchecked") + private Relationship parseRelationship(Map map) { + Relationship rel = new Relationship(); + rel.setName((String) map.get("name")); + rel.setFrom((String) map.get("from")); + rel.setTo((String) map.get("to")); + + List fromCols = (List) map.get("from_columns"); + if (fromCols != null) { + rel.setFromColumns(new ArrayList<>(fromCols)); + } + List toCols = (List) map.get("to_columns"); + if (toCols != null) { + rel.setToColumns(new ArrayList<>(toCols)); + } + return rel; + } + + @SuppressWarnings("unchecked") + private Metric parseMetric(Map map) { + Metric metric = new Metric(); + metric.setName((String) map.get("name")); + metric.setDescription((String) map.get("description")); + metric.setExpressions(parseDialectExpressions(map)); + return metric; + } + + @SuppressWarnings("unchecked") + private List parseDialectExpressions(Map map) { + List result = new ArrayList<>(); + Map exprBlock = (Map) map.get("expression"); + if (exprBlock == null) { + return result; + } + List> dialects = (List>) exprBlock.get("dialects"); + if (dialects == null) { + return result; + } + for (Map d : dialects) { + String dialect = (String) d.get("dialect"); + Object exprValue = d.get("expression"); + String expression = exprValue != null ? 
exprValue.toString() : null; + result.add(new DialectExpression(dialect, expression)); + } + return result; + } +} diff --git a/converters/spark/src/main/java/org/osi/converter/spark/OsiSparkConverter.java b/converters/spark/src/main/java/org/osi/converter/spark/OsiSparkConverter.java new file mode 100644 index 0000000..c8ac86e --- /dev/null +++ b/converters/spark/src/main/java/org/osi/converter/spark/OsiSparkConverter.java @@ -0,0 +1,71 @@ +package org.osi.converter.spark; + +import org.osi.converter.spark.model.OsiModel; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; + +/** + * CLI entry point: reads an OSI YAML semantic model and generates PySpark code. + * + *
+ * Usage:
+ *   java -jar osi-spark-converter.jar <osi_model.yaml> [-o output.py] [-d DIALECT]
+ * 
+ */ +public class OsiSparkConverter { + + public static void main(String[] args) throws IOException { + if (args.length < 1) { + System.err.println("Usage: osi-spark-converter [-o output.py] [-d DIALECT]"); + System.err.println(); + System.err.println("Options:"); + System.err.println(" -o FILE Write generated PySpark code to FILE (default: stdout)"); + System.err.println(" -d DIALECT Preferred SQL dialect: ANSI_SQL, SNOWFLAKE, DATABRICKS (default: ANSI_SQL)"); + System.exit(1); + } + + String inputFile = args[0]; + String outputFile = null; + String dialect = "ANSI_SQL"; + + for (int i = 1; i < args.length; i++) { + switch (args[i]) { + case "-o": + if (i + 1 < args.length) { + outputFile = args[++i]; + } + break; + case "-d": + if (i + 1 < args.length) { + dialect = args[++i]; + } + break; + default: + break; + } + } + + // Parse the OSI model + OsiModelParser parser = new OsiModelParser(); + OsiModel model = parser.parse(Paths.get(inputFile)); + + if (model.getSemanticModels().isEmpty()) { + System.err.println("Error: no semantic_model found in " + inputFile); + System.exit(1); + } + + // Generate PySpark code + SparkCodeGenerator generator = new SparkCodeGenerator(dialect); + String code = generator.generate(model); + + if (outputFile != null) { + Files.write(Paths.get(outputFile), code.getBytes(StandardCharsets.UTF_8)); + System.out.println("Generated PySpark code written to " + outputFile); + } else { + System.out.println(code); + } + } +} diff --git a/converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java b/converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java new file mode 100644 index 0000000..6725509 --- /dev/null +++ b/converters/spark/src/main/java/org/osi/converter/spark/SparkCodeGenerator.java @@ -0,0 +1,246 @@ +package org.osi.converter.spark; + +import org.osi.converter.spark.model.OsiModel; +import org.osi.converter.spark.model.OsiModel.*; + +import java.util.List; + +/** + * Generates PySpark code 
from a parsed {@link OsiModel}. + *

+ * The generated code creates DataFrames, registers temp views, + * builds join helpers from relationships, and exposes metric functions. + */ +public class SparkCodeGenerator { + + private final String dialect; + + public SparkCodeGenerator() { + this("ANSI_SQL"); + } + + public SparkCodeGenerator(String dialect) { + this.dialect = dialect; + } + + /** + * Generate the full PySpark module for the given OSI model. + */ + public String generate(OsiModel model) { + StringBuilder sb = new StringBuilder(); + + for (SemanticModel sm : model.getSemanticModels()) { + generateHeader(sb, sm); + generateDatasetLoaders(sb, sm); + generateLoadAll(sb, sm); + generateJoinHelpers(sb, sm); + generateMetrics(sb, sm); + generateMain(sb, sm); + } + + return sb.toString(); + } + + // ----------------------------------------------------------------------- + // Header + // ----------------------------------------------------------------------- + + private void generateHeader(StringBuilder sb, SemanticModel sm) { + sb.append("\"\"\"\n"); + sb.append("Auto-generated PySpark code from OSI semantic model: ").append(sm.getName()).append("\n"); + if (sm.getDescription() != null) { + sb.append("\n").append(sm.getDescription()).append("\n"); + } + sb.append("\"\"\"\n\n"); + sb.append("from pyspark.sql import SparkSession, DataFrame\n"); + sb.append("from pyspark.sql import functions as F\n\n\n"); + + sb.append("def get_spark() -> SparkSession:\n"); + sb.append(" \"\"\"Return or create a SparkSession.\"\"\"\n"); + sb.append(" return (\n"); + sb.append(" SparkSession.builder\n"); + sb.append(" .appName(\"").append(sm.getName()).append("\")\n"); + sb.append(" .getOrCreate()\n"); + sb.append(" )\n\n"); + } + + // ----------------------------------------------------------------------- + // Dataset loaders + // ----------------------------------------------------------------------- + + private void generateDatasetLoaders(StringBuilder sb, SemanticModel sm) { + for (Dataset ds : sm.getDatasets()) { + 
generateDatasetLoader(sb, ds); + } + } + + private void generateDatasetLoader(StringBuilder sb, Dataset ds) { + String funcName = "load_" + ds.getName(); + sb.append("\ndef ").append(funcName).append("(spark: SparkSession) -> DataFrame:\n"); + sb.append(" \"\"\"\n"); + sb.append(" Load dataset: ").append(ds.getName()).append("\n"); + if (ds.getDescription() != null) { + sb.append(" ").append(ds.getDescription()).append("\n"); + } + sb.append(" Source: ").append(ds.getSource()).append("\n"); + sb.append(" \"\"\"\n"); + + sb.append(" df = spark.table(\"").append(ds.getSource()).append("\")\n"); + + // Add computed columns for fields whose expression differs from their name + for (Field field : ds.getFields()) { + String expr = pickExpression(field.getExpressions()); + if (expr != null && !expr.equals(field.getName())) { + sb.append(" df = df.withColumn(\"").append(field.getName()) + .append("\", F.expr(\"").append(escapeString(expr)).append("\"))\n"); + } + } + + sb.append(" df.createOrReplaceTempView(\"").append(ds.getName()).append("\")\n"); + sb.append(" return df\n\n"); + } + + // ----------------------------------------------------------------------- + // Load all + // ----------------------------------------------------------------------- + + private void generateLoadAll(StringBuilder sb, SemanticModel sm) { + sb.append("\ndef load_all_datasets(spark: SparkSession) -> dict[str, DataFrame]:\n"); + sb.append(" \"\"\"Load all datasets and register temp views. 
Returns a dict of name -> DataFrame.\"\"\"\n"); + sb.append(" datasets = {}\n"); + for (Dataset ds : sm.getDatasets()) { + sb.append(" datasets[\"").append(ds.getName()).append("\"] = load_") + .append(ds.getName()).append("(spark)\n"); + } + sb.append(" return datasets\n\n"); + } + + // ----------------------------------------------------------------------- + // Join helpers + // ----------------------------------------------------------------------- + + private void generateJoinHelpers(StringBuilder sb, SemanticModel sm) { + for (Relationship rel : sm.getRelationships()) { + generateJoinHelper(sb, rel); + } + } + + private void generateJoinHelper(StringBuilder sb, Relationship rel) { + String fromDs = rel.getFrom(); + String toDs = rel.getTo(); + String funcName = "join_" + rel.getName(); + + sb.append("\ndef ").append(funcName).append("(") + .append(fromDs).append("_df: DataFrame, ") + .append(toDs).append("_df: DataFrame) -> DataFrame:\n"); + sb.append(" \"\"\"\n"); + sb.append(" Join ").append(fromDs).append(" -> ").append(toDs).append("\n"); + sb.append(" \"\"\"\n"); + + List fromCols = rel.getFromColumns(); + List toCols = rel.getToColumns(); + + if (fromCols.size() == 1) { + sb.append(" return ").append(fromDs).append("_df.join(") + .append(toDs).append("_df, ") + .append(fromDs).append("_df[\"").append(fromCols.get(0)).append("\"] == ") + .append(toDs).append("_df[\"").append(toCols.get(0)).append("\"], \"inner\")\n\n"); + } else { + String condition = buildCompositeCondition(fromDs, toDs, fromCols, toCols); + sb.append(" condition = ").append(condition).append("\n"); + sb.append(" return ").append(fromDs).append("_df.join(") + .append(toDs).append("_df, condition, \"inner\")\n\n"); + } + } + + private String buildCompositeCondition(String fromDs, String toDs, + List fromCols, List toCols) { + StringBuilder condition = new StringBuilder(); + for (int i = 0; i < fromCols.size(); i++) { + if (i > 0) { + condition.append(" & "); + } + 
condition.append(fromDs).append("_df[\"").append(fromCols.get(i)).append("\"] == ") + .append(toDs).append("_df[\"").append(toCols.get(i)).append("\"]"); + } + return condition.toString(); + } + + // ----------------------------------------------------------------------- + // Metrics + // ----------------------------------------------------------------------- + + private void generateMetrics(StringBuilder sb, SemanticModel sm) { + for (Metric metric : sm.getMetrics()) { + generateMetric(sb, metric); + } + } + + private void generateMetric(StringBuilder sb, Metric metric) { + String expr = pickExpression(metric.getExpressions()); + if (expr == null) { + return; + } + + String funcName = "compute_" + metric.getName(); + sb.append("\ndef ").append(funcName).append("(spark: SparkSession) -> DataFrame:\n"); + sb.append(" \"\"\"\n"); + sb.append(" Metric: ").append(metric.getName()).append("\n"); + if (metric.getDescription() != null) { + sb.append(" ").append(metric.getDescription()).append("\n"); + } + sb.append(" Expression: ").append(expr).append("\n"); + sb.append(" \"\"\"\n"); + sb.append(" return spark.sql(\"SELECT ").append(escapeString(expr)) + .append(" AS ").append(metric.getName()).append("\")\n\n"); + } + + // ----------------------------------------------------------------------- + // Main block + // ----------------------------------------------------------------------- + + private void generateMain(StringBuilder sb, SemanticModel sm) { + sb.append("\nif __name__ == \"__main__\":\n"); + sb.append(" spark = get_spark()\n"); + sb.append(" print(f\"Spark session started: {spark.sparkContext.appName}\")\n\n"); + + sb.append(" # Load all datasets\n"); + sb.append(" dfs = load_all_datasets(spark)\n"); + sb.append(" for name, df in dfs.items():\n"); + sb.append(" print(f\"Dataset {name}: {df.count()} rows\")\n\n"); + + if (!sm.getRelationships().isEmpty()) { + Relationship rel = sm.getRelationships().get(0); + sb.append(" # Example join: 
").append(rel.getName()).append("\n"); + sb.append(" joined = join_").append(rel.getName()) + .append("(dfs[\"").append(rel.getFrom()) + .append("\"], dfs[\"").append(rel.getTo()).append("\"])\n"); + sb.append(" joined.show(5)\n\n"); + } + + sb.append(" spark.stop()\n"); + } + + // ----------------------------------------------------------------------- + // Utilities + // ----------------------------------------------------------------------- + + /** + * Pick the expression for the preferred dialect, falling back to the first available. + */ + private String pickExpression(List expressions) { + if (expressions == null || expressions.isEmpty()) { + return null; + } + for (DialectExpression de : expressions) { + if (dialect.equals(de.getDialect())) { + return de.getExpression(); + } + } + return expressions.get(0).getExpression(); + } + + private String escapeString(String s) { + return s.replace("\\", "\\\\").replace("\"", "\\\""); + } +} diff --git a/converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java b/converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java new file mode 100644 index 0000000..a205663 --- /dev/null +++ b/converters/spark/src/main/java/org/osi/converter/spark/model/OsiModel.java @@ -0,0 +1,155 @@ +package org.osi.converter.spark.model; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Java representation of an OSI semantic model parsed from YAML. 
/**
 * Java representation of an OSI semantic model parsed from YAML.
 *
 * <p>All classes are plain mutable holders. List-valued properties are never {@code null}:
 * they start out as empty lists, and every list setter replaces a {@code null} argument with a
 * new empty list so that callers can iterate the getters without null checks. Lists passed to
 * setters are stored as-is, not copied.
 */
public class OsiModel {

    private String version;
    private List<SemanticModel> semanticModels = new ArrayList<>();

    public String getVersion() {
        return version;
    }

    public void setVersion(String version) {
        this.version = version;
    }

    public List<SemanticModel> getSemanticModels() {
        return semanticModels;
    }

    public void setSemanticModels(List<SemanticModel> semanticModels) {
        this.semanticModels = orEmpty(semanticModels);
    }

    /** Returns the given list, or a new empty mutable list when it is {@code null}. */
    private static <T> List<T> orEmpty(List<T> list) {
        return list != null ? list : new ArrayList<>();
    }

    // -----------------------------------------------------------------------
    // Nested model classes
    // -----------------------------------------------------------------------

    /** One entry of the top-level {@code semantic_model} list. */
    public static class SemanticModel {
        private String name;
        private String description;
        private List<Dataset> datasets = new ArrayList<>();
        private List<Relationship> relationships = new ArrayList<>();
        private List<Metric> metrics = new ArrayList<>();

        public String getName() { return name; }
        public void setName(String name) { this.name = name; }

        public String getDescription() { return description; }
        public void setDescription(String description) { this.description = description; }

        public List<Dataset> getDatasets() { return datasets; }
        public void setDatasets(List<Dataset> datasets) { this.datasets = orEmpty(datasets); }

        public List<Relationship> getRelationships() { return relationships; }
        public void setRelationships(List<Relationship> relationships) {
            this.relationships = orEmpty(relationships);
        }

        public List<Metric> getMetrics() { return metrics; }
        public void setMetrics(List<Metric> metrics) { this.metrics = orEmpty(metrics); }
    }

    /** A named dataset with its source, primary-key columns and fields. */
    public static class Dataset {
        private String name;
        private String source;
        private List<String> primaryKey = new ArrayList<>();
        private String description;
        private List<Field> fields = new ArrayList<>();

        public String getName() { return name; }
        public void setName(String name) { this.name = name; }

        public String getSource() { return source; }
        public void setSource(String source) { this.source = source; }

        public List<String> getPrimaryKey() { return primaryKey; }
        public void setPrimaryKey(List<String> primaryKey) { this.primaryKey = orEmpty(primaryKey); }

        public String getDescription() { return description; }
        public void setDescription(String description) { this.description = description; }

        public List<Field> getFields() { return fields; }
        public void setFields(List<Field> fields) { this.fields = orEmpty(fields); }
    }

    /** A dataset field with per-dialect expressions and a time-dimension flag. */
    public static class Field {
        private String name;
        private String description;
        private List<DialectExpression> expressions = new ArrayList<>();
        private boolean isTime;

        public String getName() { return name; }
        public void setName(String name) { this.name = name; }

        public String getDescription() { return description; }
        public void setDescription(String description) { this.description = description; }

        public List<DialectExpression> getExpressions() { return expressions; }
        public void setExpressions(List<DialectExpression> expressions) {
            this.expressions = orEmpty(expressions);
        }

        public boolean isTime() { return isTime; }
        public void setTime(boolean time) { isTime = time; }
    }

    /** An expression together with the name of the dialect it is written in. */
    public static class DialectExpression {
        private String dialect;
        private String expression;

        public DialectExpression() {}

        public DialectExpression(String dialect, String expression) {
            this.dialect = dialect;
            this.expression = expression;
        }

        public String getDialect() { return dialect; }
        public void setDialect(String dialect) { this.dialect = dialect; }

        public String getExpression() { return expression; }
        public void setExpression(String expression) { this.expression = expression; }
    }

    /** A named link between two datasets, described by parallel column lists. */
    public static class Relationship {
        private String name;
        private String from;
        private String to;
        private List<String> fromColumns = new ArrayList<>();
        private List<String> toColumns = new ArrayList<>();

        public String getName() { return name; }
        public void setName(String name) { this.name = name; }

        public String getFrom() { return from; }
        public void setFrom(String from) { this.from = from; }

        public String getTo() { return to; }
        public void setTo(String to) { this.to = to; }

        public List<String> getFromColumns() { return fromColumns; }
        public void setFromColumns(List<String> fromColumns) {
            this.fromColumns = orEmpty(fromColumns);
        }

        public List<String> getToColumns() { return toColumns; }
        public void setToColumns(List<String> toColumns) { this.toColumns = orEmpty(toColumns); }
    }

    /** A named metric with per-dialect expressions. */
    public static class Metric {
        private String name;
        private String description;
        private List<DialectExpression> expressions = new ArrayList<>();

        public String getName() { return name; }
        public void setName(String name) { this.name = name; }

        public String getDescription() { return description; }
        public void setDescription(String description) { this.description = description; }

        public List<DialectExpression> getExpressions() { return expressions; }
        public void setExpressions(List<DialectExpression> expressions) {
            this.expressions = orEmpty(expressions);
        }
    }
}
Computed total\n" + + " - name: customer\n" + + " source: db.public.customer\n" + + " primary_key: [customer_id]\n" + + " fields:\n" + + " - name: customer_id\n" + + " expression:\n" + + " dialects:\n" + + " - dialect: ANSI_SQL\n" + + " expression: customer_id\n" + + " - name: full_name\n" + + " expression:\n" + + " dialects:\n" + + " - dialect: ANSI_SQL\n" + + " expression: \"first_name || ' ' || last_name\"\n" + + " relationships:\n" + + " - name: orders_to_customer\n" + + " from: orders\n" + + " to: customer\n" + + " from_columns: [customer_id]\n" + + " to_columns: [customer_id]\n" + + " metrics:\n" + + " - name: total_revenue\n" + + " expression:\n" + + " dialects:\n" + + " - dialect: ANSI_SQL\n" + + " expression: SUM(orders.total_amount)\n" + + " description: Total revenue across all orders\n"; + + @Test + void testParseMinimalModel() { + OsiModelParser parser = new OsiModelParser(); + OsiModel model = parser.parse( + new ByteArrayInputStream(MINIMAL_MODEL.getBytes(StandardCharsets.UTF_8))); + + assertEquals("0.1.1", model.getVersion()); + assertEquals(1, model.getSemanticModels().size()); + + OsiModel.SemanticModel sm = model.getSemanticModels().get(0); + assertEquals("test_model", sm.getName()); + assertEquals(2, sm.getDatasets().size()); + assertEquals(1, sm.getRelationships().size()); + assertEquals(1, sm.getMetrics().size()); + } + + @Test + void testParseDatasetFields() { + OsiModelParser parser = new OsiModelParser(); + OsiModel model = parser.parse( + new ByteArrayInputStream(MINIMAL_MODEL.getBytes(StandardCharsets.UTF_8))); + + OsiModel.Dataset orders = model.getSemanticModels().get(0).getDatasets().get(0); + assertEquals("orders", orders.getName()); + assertEquals("db.public.orders", orders.getSource()); + assertEquals(2, orders.getFields().size()); + + OsiModel.Field computed = orders.getFields().get(1); + assertEquals("total_amount", computed.getName()); + assertEquals("quantity * unit_price", computed.getExpressions().get(0).getExpression()); + } 
+ + @Test + void testGenerateContainsExpectedFunctions() { + OsiModelParser parser = new OsiModelParser(); + OsiModel model = parser.parse( + new ByteArrayInputStream(MINIMAL_MODEL.getBytes(StandardCharsets.UTF_8))); + + SparkCodeGenerator generator = new SparkCodeGenerator("ANSI_SQL"); + String code = generator.generate(model); + + // Dataset loaders + assertTrue(code.contains("def load_orders(spark: SparkSession)")); + assertTrue(code.contains("def load_customer(spark: SparkSession)")); + assertTrue(code.contains("def load_all_datasets(spark: SparkSession)")); + + // Computed columns + assertTrue(code.contains("df.withColumn(\"total_amount\", F.expr(\"quantity * unit_price\"))")); + assertTrue(code.contains("df.withColumn(\"full_name\", F.expr(\"first_name || ' ' || last_name\"))")); + + // Temp views + assertTrue(code.contains("df.createOrReplaceTempView(\"orders\")")); + + // Join helper + assertTrue(code.contains("def join_orders_to_customer(")); + + // Metric + assertTrue(code.contains("def compute_total_revenue(spark: SparkSession)")); + assertTrue(code.contains("SUM(orders.total_amount) AS total_revenue")); + } + + @Test + void testGenerateWithDatabricksDialect() { + String multiDialect = + "version: \"0.1.1\"\n" + + "semantic_model:\n" + + " - name: multi\n" + + " datasets:\n" + + " - name: sales\n" + + " source: catalog.schema.sales\n" + + " fields:\n" + + " - name: amount\n" + + " expression:\n" + + " dialects:\n" + + " - dialect: ANSI_SQL\n" + + " expression: amount\n" + + " - dialect: DATABRICKS\n" + + " expression: \"CAST(amount AS DECIMAL(18,2))\"\n"; + + OsiModelParser parser = new OsiModelParser(); + OsiModel model = parser.parse( + new ByteArrayInputStream(multiDialect.getBytes(StandardCharsets.UTF_8))); + + SparkCodeGenerator generator = new SparkCodeGenerator("DATABRICKS"); + String code = generator.generate(model); + + assertTrue(code.contains("CAST(amount AS DECIMAL(18,2))")); + } +}