diff --git a/README.md b/README.md index df2972f..9169e3a 100644 --- a/README.md +++ b/README.md @@ -16,4 +16,8 @@ Then run `uv sync` Then download a Tableau archive (.twbx or .twb) Then run `uv run sqlparse -f file/to/parse/archive.twb(x) -r custom_report_name -o` to extract the report +As many modern data stacks use dbt, you will be asked to provide the manifest.json of your dbt project. +It is not mandatory, but providing it makes the report more accurate. +The file will be cached and can be reused, replaced, or deleted. + Type `uv run sqlparse --help` for more details. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ee0dcae..ecd178c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tableau-sql-parser" -version = "0.2.2" +version = "0.3.1" description = "A tool for parsing Tableau custom SQL queries and extracting useful information" authors = [ { name = "Florian Drouet", email = "florian.drouet78@gmail.com" } diff --git a/tableau_sql_parser/__init__.py b/tableau_sql_parser/__init__.py index e69de29..5bdc0d8 100644 --- a/tableau_sql_parser/__init__.py +++ b/tableau_sql_parser/__init__.py @@ -0,0 +1,2 @@ +APP_NAME = "tableau_sql_parser" +CACHE_FILENAME = "manifest.json" diff --git a/tableau_sql_parser/cli.py b/tableau_sql_parser/cli.py index 45f2e2d..109734e 100644 --- a/tableau_sql_parser/cli.py +++ b/tableau_sql_parser/cli.py @@ -1,7 +1,7 @@ import click from tableau_sql_parser.tableau_workbook import TableauWorkbook -from tableau_sql_parser.utils import generate_report +from tableau_sql_parser.utils import generate_report, resolve_manifest_path @click.command() @@ -28,7 +28,14 @@ ) def main(file_to_parse: str, report_name: str, is_output: bool) -> None: click.echo(f"File name is: {file_to_parse} and report name is: {report_name}") - my_workbook = TableauWorkbook(filename=file_to_parse, report_name=report_name) + + manifest_path = resolve_manifest_path() + + my_workbook = TableauWorkbook( + 
filename=file_to_parse, + report_name=report_name, + manifest_path=manifest_path + ) tables_names, column_names, number_queries = my_workbook._generate_output() if is_output: generate_report( diff --git a/tableau_sql_parser/dbt_manifest_parser.py b/tableau_sql_parser/dbt_manifest_parser.py new file mode 100644 index 0000000..c3e5f6e --- /dev/null +++ b/tableau_sql_parser/dbt_manifest_parser.py @@ -0,0 +1,68 @@ +import json +from pathlib import Path + + +class DbtManifestParser: + def __init__(self, manifest_path: Path) -> None: + self.manifest_path = manifest_path + self.manifest = self._load_manifest() + self.dbt_objects = self._extract_dbt_objects() + + def _load_manifest(self) -> dict: + """Load and parse the dbt manifest.json file.""" + if not self.manifest_path: + raise FileNotFoundError(f"manifest.json not found at {self.manifest_path}") + with open(self.manifest_path) as f: + return json.load(f) + + def _extract_dbt_objects(self) -> list[dict[str, str]]: + """ + Extract dbt-generated objects (models, seeds, sources) from the manifest. 
+ + Returns: + List of dicts with fields: type, schema, table + """ + results = [] + + nodes = self.manifest.get("nodes", {}) + sources = self.manifest.get("sources", {}) + + # Extract models and seeds + for node in nodes.values(): + resource_type = node.get("resource_type") + if resource_type in {"model", "seed"}: + schema = node.get("schema") + table = node.get("alias") + if schema and table: + results.append({ + "type": resource_type, + "schema": schema, + "table": table + }) + + # Extract sources + for source in sources.values(): + schema = source.get("schema") + table = source.get("identifier") + if schema and table: + results.append({ + "type": "source", + "schema": schema, + "table": table + }) + + return results + + def get_schema_table_strings(self) -> list[str]: + """Return distinct list of schema.table strings.""" + return sorted({f'{obj["schema"]}.{obj["table"]}' for obj in self.dbt_objects}) + + def get_tables(self) -> list[str]: + """Return distinct list of tables.""" + return sorted({obj["table"] for obj in self.dbt_objects}) + + def get_all_table_names(self) -> list[str]: + """Return all table names from the manifest.""" + schema_table_strings = self.get_schema_table_strings() + tables = self.get_tables() + return [*schema_table_strings, *tables] diff --git a/tableau_sql_parser/output_formatting.py b/tableau_sql_parser/output_formatting.py index deea287..6207e37 100644 --- a/tableau_sql_parser/output_formatting.py +++ b/tableau_sql_parser/output_formatting.py @@ -1,10 +1,17 @@ class OutputFormatting: - def __init__(self, report_name: str, alias: dict, columns: dict) -> None: + def __init__( + self, + report_name: str, + alias: dict, + columns: dict, + dbt_table_names: list + ) -> None: self.report_name = report_name self.tables_names = [] self.column_names = [] self.alias = alias self.columns = columns + self.dbt_table_names = dbt_table_names self.get_column_names_all() self.get_tables_names() @@ -20,12 +27,30 @@ def get_column_names(alias: dict, 
columns: dict) -> list: column_names_full.append(potential_alias) return column_names_full + @staticmethod + def filter_dbt_manifest(column_names: list, dbt_columns: list) -> list[str]: + """ + This function filters the dbt manifest to get the relevant information + for the report. + """ + filtered_elements = [] + for column in column_names: + splitted_column = column.rsplit(".", 1)[0] + if splitted_column in dbt_columns: + filtered_elements.append(column) + return filtered_elements + + def get_column_names_all(self) -> None: temp_column_names = [] for i in range(0, len(self.alias)): temp_column_names.extend( self.get_column_names(alias=self.alias[i], columns=self.columns[i]) ) + if self.dbt_table_names: + temp_column_names = self.filter_dbt_manifest( + column_names=temp_column_names, dbt_columns=self.dbt_table_names + ) self.column_names = sorted([*set(temp_column_names)]) def get_tables_names(self) -> None: diff --git a/tableau_sql_parser/tableau_workbook.py b/tableau_sql_parser/tableau_workbook.py index 822297b..0e92467 100644 --- a/tableau_sql_parser/tableau_workbook.py +++ b/tableau_sql_parser/tableau_workbook.py @@ -5,6 +5,7 @@ import lxml.etree import sqlfluff +from tableau_sql_parser.dbt_manifest_parser import DbtManifestParser from tableau_sql_parser.output_formatting import OutputFormatting from tableau_sql_parser.recursive_search import RecursiveSearch @@ -14,7 +15,12 @@ class TableauWorkbook: Defines a workbook object from a filename. 
""" - def __init__(self, filename: str, report_name: str) -> None: + def __init__( + self, + filename: str, + report_name: str, + manifest_path: str = None + ) -> None: self.filename = os.path.normpath(filename) self.report_name = report_name self.xml = self._get_xml() @@ -25,6 +31,12 @@ def __init__(self, filename: str, report_name: str) -> None: self.columns, self.alias, ) = self._recursive_search_sql() + if manifest_path: + self.dbt_table_names = DbtManifestParser( + manifest_path=manifest_path + ).get_all_table_names() + else: + self.dbt_table_names = None def _get_xml(self) -> lxml.etree._Element: """ @@ -97,5 +109,6 @@ def _generate_output(self) -> tuple[list, list, int]: report_name=self.report_name, alias=self.alias, columns=self.columns, + dbt_table_names=self.dbt_table_names, ) return report.tables_names, report.column_names, number_queries_analyzed diff --git a/tableau_sql_parser/utils.py b/tableau_sql_parser/utils.py index 98c4253..815e1ae 100644 --- a/tableau_sql_parser/utils.py +++ b/tableau_sql_parser/utils.py @@ -1,4 +1,10 @@ +import os +import shutil + import click +from platformdirs import user_cache_dir + +from tableau_sql_parser import APP_NAME, CACHE_FILENAME def tree_output(column_names: list) -> str: @@ -48,3 +54,47 @@ def generate_report( f.write("\nColumns are:\n") f.write(tree) f.write("---\n") + + +def resolve_manifest_path() -> str: + """Manage cached manifest.json: use, replace, or delete.""" + cache_dir = user_cache_dir(APP_NAME) + os.makedirs(cache_dir, exist_ok=True) + cached_manifest_path = os.path.join(cache_dir, CACHE_FILENAME) + + if os.path.exists(cached_manifest_path): + click.echo("📦 Cached manifest.json file detected.") + action = click.prompt( + "Do you want to use the cached file? 
(use / replace / delete)", + type=click.Choice(["use", "replace", "delete"]), + default="use" + ) + + if action == "use": + click.echo(f"Using cached manifest: {cached_manifest_path}") + return cached_manifest_path + + elif action == "replace": + manifest_path = click.prompt( + "Enter the path to the new manifest file", + type=click.Path(exists=True) + ) + shutil.copy(manifest_path, cached_manifest_path) + click.echo(f"✅ Replaced cached manifest with: {manifest_path}") + return cached_manifest_path + + elif action == "delete": + os.remove(cached_manifest_path) + click.echo("🗑️ Deleted cached manifest file.") + return None + else: + if click.confirm("Do you want to add a manifest file?", default=False): + manifest_path = click.prompt( + "Enter the path to the manifest file", + type=click.Path(exists=True) + ) + shutil.copy(manifest_path, cached_manifest_path) + click.echo(f"✅ Cached manifest: {cached_manifest_path}") + return cached_manifest_path + + return None diff --git a/tests/test_dbt_manifest_parser.py b/tests/test_dbt_manifest_parser.py new file mode 100644 index 0000000..2304cc1 --- /dev/null +++ b/tests/test_dbt_manifest_parser.py @@ -0,0 +1,86 @@ +import json +from pathlib import Path + +import pytest + +from tableau_sql_parser.dbt_manifest_parser import DbtManifestParser + + +@pytest.fixture +def manifest_file(tmp_path: str) -> Path: + """Creates a mock manifest.json file for testing.""" + manifest_data = { + "nodes": { + "model.test.model_a": { + "resource_type": "model", + "schema": "analytics", + "alias": "model_a" + }, + "seed.test.seed_a": { + "resource_type": "seed", + "schema": "public", + "alias": "seed_a" + } + }, + "sources": { + "source.test.source_a": { + "schema": "raw", + "identifier": "source_a" + } + } + } + path = tmp_path / "manifest.json" + path.write_text(json.dumps(manifest_data)) + return path + + +def test_load_manifest(manifest_file: Path) -> None: + parser = DbtManifestParser(manifest_file) + assert isinstance(parser.manifest, 
dict) + assert "nodes" in parser.manifest + assert "sources" in parser.manifest + + +def test_extract_dbt_objects(manifest_file: Path) -> None: + parser = DbtManifestParser(manifest_file) + expected = [ + {"type": "model", "schema": "analytics", "table": "model_a"}, + {"type": "seed", "schema": "public", "table": "seed_a"}, + {"type": "source", "schema": "raw", "table": "source_a"}, + ] + assert parser.dbt_objects == expected + + +def test_get_schema_table_strings(manifest_file: Path) -> None: + parser = DbtManifestParser(manifest_file) + expected = sorted([ + "analytics.model_a", + "public.seed_a", + "raw.source_a" + ]) + assert parser.get_schema_table_strings() == expected + + +def test_get_tables(manifest_file: Path) -> None: + parser = DbtManifestParser(manifest_file) + expected = sorted(["model_a", "seed_a", "source_a"]) + assert parser.get_tables() == expected + + +def test_get_all_table_names(manifest_file: Path) -> None: + parser = DbtManifestParser(manifest_file) + expected = sorted([ + "analytics.model_a", + "public.seed_a", + "raw.source_a", + "model_a", + "seed_a", + "source_a" + ]) + assert sorted(parser.get_all_table_names()) == expected + + +def test_missing_file_raises(tmp_path: str) -> None: + non_existent_path = tmp_path / "missing_manifest.json" + with pytest.raises(FileNotFoundError): + DbtManifestParser(non_existent_path) diff --git a/uv.lock b/uv.lock index 0bf51b6..0128141 100644 --- a/uv.lock +++ b/uv.lock @@ -357,7 +357,7 @@ wheels = [ [[package]] name = "tableau-sql-parser" -version = "0.1.2" +version = "0.3.1" source = { editable = "." } dependencies = [ { name = "lxml" },