From 88e210b250d458cf531caffd66a15418d09f021f Mon Sep 17 00:00:00 2001 From: scottwilson Date: Mon, 9 Mar 2026 15:02:14 +0000 Subject: [PATCH 01/24] feat: Added a FieldMapping class to handle transformations for physical field names --- mario/mapping.py | 25 +++++++++++++++++++++++++ mario/utils.py | 12 ++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 mario/mapping.py diff --git a/mario/mapping.py b/mario/mapping.py new file mode 100644 index 0000000..9337c8f --- /dev/null +++ b/mario/mapping.py @@ -0,0 +1,25 @@ +from typing import List + +from mario.utils import to_snake_case + + +class FieldMapping: + """ + Utility class for holding a set of logical-to-physical item mappings + """ + + def __init__(self, query_format, items: List[str]): + self._format = query_format + + self.as_physical = {} + self.as_logical = {} + + for item in items: + self.as_physical[item] = self.map_item(item) + self.as_logical[self.map_item(item)] = item + + def map_item(self, item): + if self._format is None: + return item + if self._format == 'snake_case': + return to_snake_case(item) diff --git a/mario/utils.py b/mario/utils.py index ffe3149..f4a21ea 100644 --- a/mario/utils.py +++ b/mario/utils.py @@ -7,3 +7,15 @@ def append_current_date_to_file_name(file_name: str) -> str: filename = os.path.splitext(os.path.basename(file_name))[0] extension = os.path.splitext(os.path.basename(file_name))[1] return filename + '_' + date_time + extension + + +def to_snake_case(name: str) -> str: + """ + Convert a name to lowercase_with_underscores. + Example: "Academic Year" -> "academic_year". 
+ """ + import re + name = name.strip().lower() + name = re.sub(r"[^\w]+", "_", name) + name = re.sub(r"__+", "_", name).strip("_") + return name \ No newline at end of file From c4fbc464c11fdcc4066169fac0f66ce67046665e Mon Sep 17 00:00:00 2001 From: scottwilson Date: Mon, 9 Mar 2026 15:04:13 +0000 Subject: [PATCH 02/24] feat: Added FieldMapping to Data Extractor and Query Builder --- mario/data_extractor.py | 11 +++++++++-- mario/query_builder.py | 9 ++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mario/data_extractor.py b/mario/data_extractor.py index 31d5d46..b031d6e 100644 --- a/mario/data_extractor.py +++ b/mario/data_extractor.py @@ -7,6 +7,7 @@ from mario.dataset_specification import DatasetSpecification from mario.metadata import Metadata from mario.options import CsvOptions, HyperOptions +from mario.mapping import FieldMapping logger = logging.getLogger(__name__) @@ -28,7 +29,8 @@ def __init__(self, schema: str = None, file_path: str = None, query_builder=None, - user: str = None + user: str = None, + query_format=None ): self.connection_string = connection_string self.hook = hook @@ -37,6 +39,7 @@ def __init__(self, self.file_path = file_path self.query_builder = query_builder self.user = user + self.query_format = query_format class DataExtractor: @@ -51,6 +54,10 @@ def __init__(self, self._data = None self._query = None self._total = 0 + self.mapping = FieldMapping( + query_format=configuration.query_format, + items=dataset_specification.items + ) def __load__(self): if self.configuration is not None: @@ -113,7 +120,7 @@ def __get_column_name__(self, item: str): elif meta.get_property('physical_column_name') is not None: return meta.get_property('physical_column_name') else: - return meta.name + return self.mapping.as_physical[meta.name] def __minimise_data__(self): """ Minimise data so we only keep the columns in the spec """ diff --git a/mario/query_builder.py b/mario/query_builder.py index 4093b23..cac2a55 100644 --- 
a/mario/query_builder.py +++ b/mario/query_builder.py @@ -7,6 +7,7 @@ from mario.data_extractor import Configuration from mario.dataset_specification import DatasetSpecification from mario.metadata import Metadata +from mario.mapping import FieldMapping logger = logging.getLogger(__name__) @@ -31,6 +32,7 @@ def __init__(self, self.configuration = configuration self.metadata = metadata self.dataset_specification = dataset_specification + self.mapping = FieldMapping(query_format=configuration.query_format, items=dataset_specification.items) def create_query(self) -> [str, List[any]]: raise NotImplementedError @@ -57,7 +59,7 @@ def create_totals_query(self, measure=None) -> [str, List[any]]: if measure is None: _sql = f'SELECT COUNT(*) FROM "' + self.configuration.schema + '"."' + self.configuration.view + '"' else: - _sql = f'SELECT SUM("'+measure+'") FROM "' + self.configuration.schema + '"."' + self.configuration.view + '"' + _sql = f'SELECT SUM("'+self.mapping.as_physical[measure]+'") FROM "' + self.configuration.schema + '"."' + self.configuration.view + '"' _params = [] return [_sql, _params] @@ -127,12 +129,13 @@ def create_query(self) -> [str, List[any]]: # Don't include calculated fields meta = self.metadata.get_metadata(field) if not meta.get_property('formula'): - select_fields.append(field) + select_fields.append(self.mapping.as_physical[field]) group_fields = select_fields.copy() # remove measures from regular select measures = [] - for measure in self.dataset_specification.measures: + for measure_field in self.dataset_specification.measures: + measure = self.mapping.as_physical[measure_field] if measure in select_fields: select_fields.remove(measure) group_fields.remove(measure) From 6d476a7b790be996c4d7029eaeafa963c624c925 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Mon, 9 Mar 2026 15:05:47 +0000 Subject: [PATCH 03/24] feat: Added Athena data extractor --- mario/athena.py | 49 ++++++++++++++++++++++++++++++ requirements.txt | 3 +- setup.py | 3 +- 
test/test_athena_extractor.py | 56 +++++++++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 mario/athena.py create mode 100644 test/test_athena_extractor.py diff --git a/mario/athena.py b/mario/athena.py new file mode 100644 index 0000000..1c1e759 --- /dev/null +++ b/mario/athena.py @@ -0,0 +1,49 @@ +from pyathena import connect +from mario.data_extractor import StreamingDataExtractor, Configuration +import logging + +logger = logging.getLogger(__name__) + + +class AthenaConfiguration(Configuration): + """ + Extended configuration + """ + + def __init__(self): + super().__init__() + self.aws_s3_staging_dir = 's3://celeste-iceberg/athena-results/' + self.aws_region_name = 'eu-west-2' + self.aws_athena_workgroup = 'primary' + self.catalog = 'awsdatacatalog' + self.query_format = 'snake_case' + + +class AthenaStreamingDataExtractor(StreamingDataExtractor): + """ + Streaming extractor using Athena + PyAthena. + Extends StreamingDataExtractor; only get_connection() is Athena-specific. + """ + + def __init__(self, configuration: AthenaConfiguration, dataset_specification, metadata): + super().__init__(configuration, dataset_specification, metadata) + self.configuration = configuration + + def get_connection(self): + """ + PyAthena provides a DBAPI connection compatible with pandas.read_sql() including chunksize. 
+ """ + cfg = self.configuration + + if cfg.hook: + # If the user provided a hook, delegate to it + return cfg.hook.get_conn() + + # Expect these in configuration: + return connect( + s3_staging_dir=cfg.aws_s3_staging_dir, + region_name=cfg.aws_region_name, + work_group=cfg.aws_athena_workgroup, + schema_name=cfg.schema, + catalog_name=cfg.catalog + ) diff --git a/requirements.txt b/requirements.txt index 1b7fbb1..0e1dbc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ tableauhyperapi~=0.0.18161 tableau-builder==0.18 pypika sqlalchemy -openpyxl \ No newline at end of file +openpyxl +pyathena \ No newline at end of file diff --git a/setup.py b/setup.py index e989d3a..28673cd 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ ], extras_require={ 'Airflow': ['apache-airflow-providers-common-sql'], - 'Tableau': ['pantab', 'tableauhyperapi', 'tableau-builder==0.18'] + 'Tableau': ['pantab', 'tableauhyperapi', 'tableau-builder==0.18'], + 'Athena': ['pyathena'] } ) \ No newline at end of file diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py new file mode 100644 index 0000000..dc020ae --- /dev/null +++ b/test/test_athena_extractor.py @@ -0,0 +1,56 @@ +from mario.athena import AthenaConfiguration, AthenaStreamingDataExtractor +from mario.dataset_specification import DatasetSpecification +from mario.query_builder import SubsetQueryBuilder, ViewBasedQueryBuilder +from mario.metadata import Metadata, Item +import os + + +def test_athena_stream(): + os.makedirs('output/test_athena', exist_ok=True) + + dataset = DatasetSpecification() + dataset.dimensions = [ + "Academic Year", + "Mode of study", + "Country of HE provider" + ] + dataset.measures = ['Number'] + dataset.name = 'student_open_data' + metadata = Metadata() + academic_year = Item() + academic_year.name = 'Academic Year' + mode_of_study = Item() + mode_of_study.name = 'Mode of study' + country_of_he_provider = Item() + country_of_he_provider.name = 'Country of HE 
provider' + number_field = Item() + number_field.name = 'Number' + metadata.add_item(academic_year) + metadata.add_item(mode_of_study) + metadata.add_item(number_field) + metadata.add_item(country_of_he_provider) + + cfg = AthenaConfiguration() + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaStreamingDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + + extractor.stream_sql_to_csv( + file_path='output/test_athena/test.csv', + minimise=True, + compress_using_gzip=False, + do_not_modify_source=True + ) + + extractor.stream_sql_to_hyper( + file_path='output/test_athena/test.hyper', + minimise=True, + compress_using_gzip=False, + do_not_modify_source=True + ) From c8b0a35901f387f3e4bc54cf7ebb0b8598b53ab3 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 11 Mar 2026 11:43:19 +0000 Subject: [PATCH 04/24] feat: reverse physical->logical mapping when writing data --- mario/data_extractor.py | 18 ++++++------------ mario/mapping.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mario/data_extractor.py b/mario/data_extractor.py index b031d6e..6d6a654 100644 --- a/mario/data_extractor.py +++ b/mario/data_extractor.py @@ -83,7 +83,9 @@ def __load_from_sql__(self): logger.info("Executing query") from sqlalchemy import create_engine engine = create_engine(self.configuration.connection_string) - self._data = pd.read_sql(sql=self._query[0], con=engine.connect(), params=self._query[1]) + df = pd.read_sql(sql=self._query[0], con=engine.connect(), params=self._query[1]) + df = self.mapping.df_to_logical(df) + self._data = df def __build_query__(self): logger.info("Building query") @@ -112,22 +114,12 @@ def __load_from_hyper__(self): table=table ) - def __get_column_name__(self, item: str): - """ Returns the column name for a metadata item""" - meta = self.metadata.get_metadata(item) - if meta.get_property('output_name') is not None: - 
return meta.get_property('output_name') - elif meta.get_property('physical_column_name') is not None: - return meta.get_property('physical_column_name') - else: - return self.mapping.as_physical[meta.name] - def __minimise_data__(self): """ Minimise data so we only keep the columns in the spec """ columns_to_keep = [] for item in self.dataset_specification.items: if self.metadata.get_metadata(item) and not self.metadata.get_metadata(item).get_property('formula'): - columns_to_keep.append(self.__get_column_name__(item)) + columns_to_keep.append(item) self._data = self._data[columns_to_keep] def __get_measure__(self, measure=None): @@ -403,6 +395,7 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs): table_name = TableName(options.schema, options.table) row_counter = 0 for df in pd.read_sql(self._query[0], connection, chunksize=options.chunk_size): + df = self.mapping.df_to_logical(df) if options.validate or options.minimise or options.include_row_numbers: self._data = df if options.validate: @@ -429,6 +422,7 @@ def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, * header = True for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size): + df = self.mapping.df_to_logical(df) if options.validate or options.minimise: self._data = df if options.validate: diff --git a/mario/mapping.py b/mario/mapping.py index 9337c8f..d7584d2 100644 --- a/mario/mapping.py +++ b/mario/mapping.py @@ -1,5 +1,7 @@ from typing import List +import pandas as pd + from mario.utils import to_snake_case @@ -23,3 +25,11 @@ def map_item(self, item): return item if self._format == 'snake_case': return to_snake_case(item) + + def df_to_logical(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Renames the columns in the dataframe from their physical to their logical names + :param df: pandas Dataframe + :return: pandas Dataframe + """ + return df.rename(columns=self.as_logical) From 754a1167d5400e79addce49b7c1bb2d6841e1246 
Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 11 Mar 2026 11:43:29 +0000 Subject: [PATCH 05/24] test: added testing for column spec --- test/test_athena_extractor.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index dc020ae..9ae89a3 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -1,8 +1,9 @@ from mario.athena import AthenaConfiguration, AthenaStreamingDataExtractor from mario.dataset_specification import DatasetSpecification -from mario.query_builder import SubsetQueryBuilder, ViewBasedQueryBuilder +from mario.query_builder import SubsetQueryBuilder from mario.metadata import Metadata, Item import os +import pandas as pd def test_athena_stream(): @@ -48,6 +49,13 @@ def test_athena_stream(): do_not_modify_source=True ) + # Load and test + df = pd.read_csv('output/test_athena/test.csv') + for column in dataset.dimensions: + assert column in df.columns + assert 'Number' in df.columns + assert len(df.columns) == len(dataset.items) + extractor.stream_sql_to_hyper( file_path='output/test_athena/test.hyper', minimise=True, From 6f3385cbdc25bb16741ef16e37d38132b02e3a60 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Thu, 12 Mar 2026 08:50:11 +0000 Subject: [PATCH 06/24] test: added testing for totals --- test/test_athena_extractor.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index 9ae89a3..0b6a3ce 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -3,12 +3,11 @@ from mario.query_builder import SubsetQueryBuilder from mario.metadata import Metadata, Item import os +import shutil import pandas as pd -def test_athena_stream(): - os.makedirs('output/test_athena', exist_ok=True) - +def get_test_conf(): dataset = DatasetSpecification() dataset.dimensions = [ "Academic Year", @@ -30,7 +29,14 @@ 
def test_athena_stream(): metadata.add_item(mode_of_study) metadata.add_item(number_field) metadata.add_item(country_of_he_provider) + return dataset, metadata + +def test_athena_stream(): + shutil.rmtree('output/test_athena', ignore_errors=True) + os.makedirs('output/test_athena', exist_ok=True) + + dataset, metadata = get_test_conf() cfg = AthenaConfiguration() cfg.query_builder = SubsetQueryBuilder cfg.schema = 'demo' @@ -62,3 +68,22 @@ def test_athena_stream(): compress_using_gzip=False, do_not_modify_source=True ) + + +def test_athena_count(): + + dataset, metadata = get_test_conf() + cfg = AthenaConfiguration() + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaStreamingDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + + total = extractor.get_total(measure=dataset.measures[0]) + print("total", total) # 28,733,910 + From 0c1d143ef2c60620d23d9d32250f452686159c21 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Thu, 12 Mar 2026 09:18:23 +0000 Subject: [PATCH 07/24] fix: dropped some of the defaults and skip tests if we don't have access to AWS --- mario/athena.py | 4 +-- test/test_athena_extractor.py | 65 ++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/mario/athena.py b/mario/athena.py index 1c1e759..dad7e0a 100644 --- a/mario/athena.py +++ b/mario/athena.py @@ -12,8 +12,8 @@ class AthenaConfiguration(Configuration): def __init__(self): super().__init__() - self.aws_s3_staging_dir = 's3://celeste-iceberg/athena-results/' - self.aws_region_name = 'eu-west-2' + self.aws_s3_staging_dir = None + self.aws_region_name = None self.aws_athena_workgroup = 'primary' self.catalog = 'awsdatacatalog' self.query_format = 'snake_case' diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index 0b6a3ce..12d0838 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -1,3 
+1,5 @@ +import pytest + from mario.athena import AthenaConfiguration, AthenaStreamingDataExtractor from mario.dataset_specification import DatasetSpecification from mario.query_builder import SubsetQueryBuilder @@ -29,15 +31,23 @@ def get_test_conf(): metadata.add_item(mode_of_study) metadata.add_item(number_field) metadata.add_item(country_of_he_provider) - return dataset, metadata + + config = AthenaConfiguration() + config.aws_s3_staging_dir = 's3://celeste-iceberg/athena-results/' + config.aws_region_name = 'eu-west-2' + + return dataset, metadata, config -def test_athena_stream(): +def test_athena_stream_sql_to_csv(): + # Skip this test if we don't have an AWS profile + if not os.environ.get('AWS_PROFILE'): + pytest.skip("Skipping Athena test as no AWS profile configured") + shutil.rmtree('output/test_athena', ignore_errors=True) os.makedirs('output/test_athena', exist_ok=True) - dataset, metadata = get_test_conf() - cfg = AthenaConfiguration() + dataset, metadata, cfg = get_test_conf() cfg.query_builder = SubsetQueryBuilder cfg.schema = 'demo' cfg.view = 'student_open_data' @@ -62,18 +72,14 @@ def test_athena_stream(): assert 'Number' in df.columns assert len(df.columns) == len(dataset.items) - extractor.stream_sql_to_hyper( - file_path='output/test_athena/test.hyper', - minimise=True, - compress_using_gzip=False, - do_not_modify_source=True - ) - def test_athena_count(): + # Skip this test if we don't have an AWS profile + if not os.environ.get('AWS_PROFILE'): + pytest.skip("Skipping Athena test as no AWS profile configured") - dataset, metadata = get_test_conf() - cfg = AthenaConfiguration() + + dataset, metadata, cfg = get_test_conf() cfg.query_builder = SubsetQueryBuilder cfg.schema = 'demo' cfg.view = 'student_open_data' @@ -87,3 +93,36 @@ def test_athena_count(): total = extractor.get_total(measure=dataset.measures[0]) print("total", total) # 28,733,910 + +def test_athena_save_data_as_csv(): + # Skip this test if we don't have an AWS profile + if not 
os.environ.get('AWS_PROFILE'): + pytest.skip("Skipping Athena test as no AWS profile configured") + + shutil.rmtree('output/test_athena', ignore_errors=True) + os.makedirs('output/test_athena', exist_ok=True) + + dataset, metadata, cfg = get_test_conf() + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaStreamingDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + + extractor.save_data_as_csv( + file_path='output/test_athena/test.csv', + minimise=False, + compress_using_gzip=False, + do_not_modify_source=True + ) + + # Load and test + df = pd.read_csv('output/test_athena/test.csv') + for column in dataset.dimensions: + assert column in df.columns + assert 'Number' in df.columns + assert len(df.columns) == len(dataset.items) From 00202d0e219f8e884e39d6d31313a6d7d74623a7 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 18 Mar 2026 10:22:08 +0000 Subject: [PATCH 08/24] fix: added FieldMapping to SqlValidator, and use hooks where available instead of SqlAlchemy --- mario/validation.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/mario/validation.py b/mario/validation.py index d22bc9a..1b35ecb 100644 --- a/mario/validation.py +++ b/mario/validation.py @@ -3,6 +3,7 @@ from mario.data_extractor import Configuration from mario.dataset_specification import DatasetSpecification +from mario.mapping import FieldMapping from mario.metadata import Metadata, Item logger = logging.getLogger(__name__) @@ -416,9 +417,23 @@ def __init__(self, configuration: Configuration ): super().__init__(dataset_specification, metadata) - self.connection = self.__get_connection__(configuration.connection_string) + if configuration.hook: + # If the user provided a hook, delegate to it + self.connection = configuration.hook.get_conn() + else: + self.connection = self.__get_connection__(configuration.connection_string) self.schema = 
configuration.schema self.view = configuration.view + self.mapping = FieldMapping(query_format=configuration.query_format, items=dataset_specification.items) + + def __get_column_name__(self, item: Item): + """ Returns the column name for a metadata item""" + if item.get_property('output_name') is not None: + return item.get_property('output_name') + elif item.get_property('physical_column_name') is not None: + return item.get_property('physical_column_name') + else: + return self.mapping.as_physical[item.name] def __get_connection__(self, connection_string: str): from sqlalchemy import create_engine @@ -492,7 +507,7 @@ def check_column_present(self, item: Item): sql = f"SELECT COLUMN_NAME as col FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{self.view}' AND TABLE_SCHEMA='{self.schema}'" df = pd.read_sql(sql, self.connection) values = df['col'].to_list() - if item.name not in values: + if self.__get_column_name__(item) not in values: self.errors.append(f"Validation error: '{item.name}' in specification is missing from dataset") return False return True From 55743b3d40664d81e9673e90104b14da7aeb8645 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 18 Mar 2026 10:22:26 +0000 Subject: [PATCH 09/24] test: Added validation test --- test/test_athena_extractor.py | 53 +++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index 12d0838..20a4217 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -1,3 +1,5 @@ +from copy import copy + import pytest from mario.athena import AthenaConfiguration, AthenaStreamingDataExtractor @@ -8,6 +10,19 @@ import shutil import pandas as pd +from mario.validation import SqlValidator + +AWS_ATHENA_RESULTS_DIR = os.environ.get('AWS_ATHENA_RESULTS_DIR') +AWS_REGION = os.environ.get('AWS_REGION') + + +class MockHook: + def __init__(self, extractor: AthenaStreamingDataExtractor): + self.extractor = 
extractor + + def get_conn(self): + return self.extractor.get_connection() + def get_test_conf(): dataset = DatasetSpecification() @@ -23,8 +38,10 @@ def get_test_conf(): academic_year.name = 'Academic Year' mode_of_study = Item() mode_of_study.name = 'Mode of study' + mode_of_study.set_property('domain', ['Full-time', 'Part-time']) country_of_he_provider = Item() country_of_he_provider.name = 'Country of HE provider' + country_of_he_provider.set_property('domain', ['England', 'Wales', 'Scotland', 'Northern Ireland']) number_field = Item() number_field.name = 'Number' metadata.add_item(academic_year) @@ -33,8 +50,8 @@ def get_test_conf(): metadata.add_item(country_of_he_provider) config = AthenaConfiguration() - config.aws_s3_staging_dir = 's3://celeste-iceberg/athena-results/' - config.aws_region_name = 'eu-west-2' + config.aws_s3_staging_dir = AWS_ATHENA_RESULTS_DIR + config.aws_region_name = AWS_REGION return dataset, metadata, config @@ -126,3 +143,35 @@ def test_athena_save_data_as_csv(): assert column in df.columns assert 'Number' in df.columns assert len(df.columns) == len(dataset.items) + + +def test_athena_validate(): + # Skip this test if we don't have an AWS profile + if not os.environ.get('AWS_PROFILE'): + pytest.skip("Skipping Athena test as no AWS profile configured") + + shutil.rmtree('output/test_athena', ignore_errors=True) + os.makedirs('output/test_athena', exist_ok=True) + + dataset, metadata, cfg = get_test_conf() + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaStreamingDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + mock_hook = MockHook(extractor) + cfg_validator = copy(cfg) + cfg_validator.hook = mock_hook + validator = SqlValidator( + dataset_specification=dataset, + configuration=cfg_validator, + metadata=metadata + ) + validator.validate_data(allow_nulls=False) + + + From 04c4abcd90a0a8e2a2f889bd729feb9c4caf6099 Mon Sep 17 
00:00:00 2001 From: scottwilson Date: Wed, 18 Mar 2026 10:22:46 +0000 Subject: [PATCH 10/24] doc: added link to hook docs for Airflow+Athena --- mario/athena.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mario/athena.py b/mario/athena.py index dad7e0a..4aea639 100644 --- a/mario/athena.py +++ b/mario/athena.py @@ -37,6 +37,7 @@ def get_connection(self): if cfg.hook: # If the user provided a hook, delegate to it + # See: https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/_api/airflow/providers/amazon/aws/hooks/athena_sql/index.html return cfg.hook.get_conn() # Expect these in configuration: From 98aebd73eefbaf12f4599c6a39d00f3c496fb6d1 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Thu, 19 Mar 2026 11:41:00 +0000 Subject: [PATCH 11/24] test: skip Athena tests if any env vars aren't set --- test/test_athena_extractor.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index 20a4217..3c3609a 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -12,6 +12,7 @@ from mario.validation import SqlValidator +AWS_PROFILE = os.environ.get('AWS_PROFILE') AWS_ATHENA_RESULTS_DIR = os.environ.get('AWS_ATHENA_RESULTS_DIR') AWS_REGION = os.environ.get('AWS_REGION') @@ -57,9 +58,9 @@ def get_test_conf(): def test_athena_stream_sql_to_csv(): - # Skip this test if we don't have an AWS profile - if not os.environ.get('AWS_PROFILE'): - pytest.skip("Skipping Athena test as no AWS profile configured") + # Skip this test if we don't have AWS env vars + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") shutil.rmtree('output/test_athena', ignore_errors=True) os.makedirs('output/test_athena', exist_ok=True) @@ -91,10 +92,8 @@ def test_athena_stream_sql_to_csv(): def test_athena_count(): - # Skip this test if we don't have an AWS profile - if not 
os.environ.get('AWS_PROFILE'): - pytest.skip("Skipping Athena test as no AWS profile configured") - + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") dataset, metadata, cfg = get_test_conf() cfg.query_builder = SubsetQueryBuilder @@ -112,9 +111,8 @@ def test_athena_count(): def test_athena_save_data_as_csv(): - # Skip this test if we don't have an AWS profile - if not os.environ.get('AWS_PROFILE'): - pytest.skip("Skipping Athena test as no AWS profile configured") + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") shutil.rmtree('output/test_athena', ignore_errors=True) os.makedirs('output/test_athena', exist_ok=True) @@ -146,9 +144,8 @@ def test_athena_save_data_as_csv(): def test_athena_validate(): - # Skip this test if we don't have an AWS profile - if not os.environ.get('AWS_PROFILE'): - pytest.skip("Skipping Athena test as no AWS profile configured") + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") shutil.rmtree('output/test_athena', ignore_errors=True) os.makedirs('output/test_athena', exist_ok=True) From 9e211a1c00fc103c4766c199d0276f08a200d660 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Fri, 20 Mar 2026 09:08:28 +0000 Subject: [PATCH 12/24] feat: refactored Athena as a bulk rather than streaming extractor, as the way Athena handles data doesn't really support streaming quite the way you'd expect. 
--- mario/athena.py | 77 ++++++++++++++++++++++- mario/mapping.py | 35 +++++++++++ mario/utils.py | 12 +++- requirements.txt | 3 +- setup.py | 2 +- test/test_athena_extractor.py | 112 ++++++++++++++++++++++------------ 6 files changed, 198 insertions(+), 43 deletions(-) diff --git a/mario/athena.py b/mario/athena.py index 4aea639..cac9376 100644 --- a/mario/athena.py +++ b/mario/athena.py @@ -1,7 +1,11 @@ from pyathena import connect -from mario.data_extractor import StreamingDataExtractor, Configuration +from mario.data_extractor import DataExtractor, Configuration import logging +from mario.mapping import rewrite_csv_header_with_fieldmapping +from mario.options import CsvOptions +from mario.utils import gzip_file + logger = logging.getLogger(__name__) @@ -19,7 +23,7 @@ def __init__(self): self.query_format = 'snake_case' -class AthenaStreamingDataExtractor(StreamingDataExtractor): +class AthenaDataExtractor(DataExtractor): """ Streaming extractor using Athena + PyAthena. Extends StreamingDataExtractor; only get_connection() is Athena-specific. 
@@ -48,3 +52,72 @@ def get_connection(self): schema_name=cfg.schema, catalog_name=cfg.catalog ) + + def get_total(self, measure=None): + """ + For totals when streaming data we need to run a totals SQL query separate + from the main query and use the results of this + :return: the total value of the query + TODO this is a direct copy from StreamingDataExtractor + """ + from pandas import read_sql + logger.info("Building totals query") + measure = self.__get_measure__(measure) + if self.configuration.query_builder is not None: + from mario.query_builder import QueryBuilder + query_builder: QueryBuilder = self.configuration.query_builder( + configuration=self.configuration, + metadata=self.metadata, + dataset_specification=self.dataset_specification) + totals_query = query_builder.create_totals_query(measure=measure) + else: + raise NotImplementedError + + totals_df = read_sql(totals_query[0], self.get_connection(), params=totals_query[1]) + return totals_df.iat[0, 0] + + def save_data_as_csv(self, file_path: str, **kwargs): + """ + Athena doesn't support streaming, but natively saves CSV files + in S3 as output so we really don't need to do anything else + other than run the query and download the results from S3 + :param file_path: + :param kwargs: + :return: + """ + import awswrangler as wr + import boto3 + + # Parse options + options = CsvOptions(**kwargs) + + # Build SQL + self.__build_query__() + sql = self._query[0] + cfg = self.configuration + + # 1. Run SQL via Wrangler + qid = wr.athena.start_query_execution( + sql=sql, + database=cfg.schema, + workgroup=cfg.aws_athena_workgroup, + s3_output=cfg.aws_s3_staging_dir, + ) + wr.athena.wait_query(qid) + + # 2. Get S3 CSV result path + client = boto3.client("athena") + meta = client.get_query_execution(QueryExecutionId=qid) + s3_uri = meta["QueryExecution"]["ResultConfiguration"]["OutputLocation"] + + # 3. Download raw Athena CSV + wr.s3.download(path=s3_uri, local_file=file_path) + + # 4. 
Rewrite header with FieldMapping + rewrite_csv_header_with_fieldmapping(file_path, self.mapping) + + if options.compress_using_gzip: + gz_path = gzip_file(file_path) + return gz_path + + return file_path diff --git a/mario/mapping.py b/mario/mapping.py index d7584d2..ca52174 100644 --- a/mario/mapping.py +++ b/mario/mapping.py @@ -1,6 +1,9 @@ from typing import List import pandas as pd +import csv +import os +import tempfile from mario.utils import to_snake_case @@ -33,3 +36,35 @@ def df_to_logical(self, df: pd.DataFrame) -> pd.DataFrame: :return: pandas Dataframe """ return df.rename(columns=self.as_logical) + + +def rewrite_csv_header_with_fieldmapping(local_path, field_mapping): + """ + Replace the first-line header in a CSV using field_mapping, but stream + all remaining lines without loading the file into memory. + """ + # Create a temp file in the same directory (safer for atomic replace) + dir_name = os.path.dirname(local_path) + fd, temp_path = tempfile.mkstemp(dir=dir_name) + os.close(fd) + + with open(local_path, "r", encoding="utf-8", newline="") as src, \ + open(temp_path, "w", encoding="utf-8", newline="") as dst: + + reader = csv.reader(src) + writer = csv.writer(dst, quoting=csv.QUOTE_MINIMAL) + + # --- Read & rewrite only the first line --- + physical_header = next(reader) + logical_header = [ + field_mapping.as_logical.get(col, col) + for col in physical_header + ] + writer.writerow(logical_header) + + # --- Stream the rest unchanged --- + for row in reader: + writer.writerow(row) + + # Atomic replace of original file + os.replace(temp_path, local_path) diff --git a/mario/utils.py b/mario/utils.py index f4a21ea..6d43f6f 100644 --- a/mario/utils.py +++ b/mario/utils.py @@ -18,4 +18,14 @@ def to_snake_case(name: str) -> str: name = name.strip().lower() name = re.sub(r"[^\w]+", "_", name) name = re.sub(r"__+", "_", name).strip("_") - return name \ No newline at end of file + return name + + +def gzip_file(input_path, output_path=None): + import gzip + 
import shutil + if output_path is None: + output_path = input_path + ".gz" + with open(input_path, "rb") as f_in, gzip.open(output_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + return output_path diff --git a/requirements.txt b/requirements.txt index 0e1dbc9..28f7792 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ tableau-builder==0.18 pypika sqlalchemy openpyxl -pyathena \ No newline at end of file +pyathena +awswrangler \ No newline at end of file diff --git a/setup.py b/setup.py index 28673cd..6992ccf 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,6 @@ extras_require={ 'Airflow': ['apache-airflow-providers-common-sql'], 'Tableau': ['pantab', 'tableauhyperapi', 'tableau-builder==0.18'], - 'Athena': ['pyathena'] + 'Athena': ['pyathena', 'awswrangler'] } ) \ No newline at end of file diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index 3c3609a..a531163 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -2,7 +2,7 @@ import pytest -from mario.athena import AthenaConfiguration, AthenaStreamingDataExtractor +from mario.athena import AthenaConfiguration, AthenaDataExtractor from mario.dataset_specification import DatasetSpecification from mario.query_builder import SubsetQueryBuilder from mario.metadata import Metadata, Item @@ -18,7 +18,7 @@ class MockHook: - def __init__(self, extractor: AthenaStreamingDataExtractor): + def __init__(self, extractor: AthenaDataExtractor): self.extractor = extractor def get_conn(self): @@ -57,38 +57,39 @@ def get_test_conf(): return dataset, metadata, config -def test_athena_stream_sql_to_csv(): - # Skip this test if we don't have AWS env vars - if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: - pytest.skip("Skipping Athena test as AWS not configured") - - shutil.rmtree('output/test_athena', ignore_errors=True) - os.makedirs('output/test_athena', exist_ok=True) - - dataset, metadata, cfg = get_test_conf() - cfg.query_builder 
= SubsetQueryBuilder - cfg.schema = 'demo' - cfg.view = 'student_open_data' - - extractor = AthenaStreamingDataExtractor( - configuration=cfg, - metadata=metadata, - dataset_specification=dataset - ) - - extractor.stream_sql_to_csv( - file_path='output/test_athena/test.csv', - minimise=True, - compress_using_gzip=False, - do_not_modify_source=True - ) - - # Load and test - df = pd.read_csv('output/test_athena/test.csv') - for column in dataset.dimensions: - assert column in df.columns - assert 'Number' in df.columns - assert len(df.columns) == len(dataset.items) +# def test_athena_stream_sql_to_csv(): +# # Skip this test if we don't have AWS env vars +# if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: +# pytest.skip("Skipping Athena test as AWS not configured") +# +# shutil.rmtree('output/test_athena', ignore_errors=True) +# os.makedirs('output/test_athena', exist_ok=True) +# +# dataset, metadata, cfg = get_test_conf() +# cfg.query_builder = SubsetQueryBuilder +# cfg.schema = 'demo' +# cfg.view = 'student_open_data' +# +# extractor = AthenaDataExtractor( +# configuration=cfg, +# metadata=metadata, +# dataset_specification=dataset +# ) +# +# extractor.stream_sql_to_csv( +# file_path='output/test_athena/test.csv', +# minimise=True, +# compress_using_gzip=False, +# do_not_modify_source=True +# ) +# +# # Load and test +# df = pd.read_csv('output/test_athena/test.csv') +# for column in dataset.dimensions: +# assert column in df.columns +# assert 'Number' in df.columns +# assert len(df.columns) == len(dataset.items) +# assert df['Number'].sum() == 28_733_910 def test_athena_count(): @@ -100,14 +101,14 @@ def test_athena_count(): cfg.schema = 'demo' cfg.view = 'student_open_data' - extractor = AthenaStreamingDataExtractor( + extractor = AthenaDataExtractor( configuration=cfg, metadata=metadata, dataset_specification=dataset ) total = extractor.get_total(measure=dataset.measures[0]) - print("total", total) # 28,733,910 + assert total == 28_733_910 def 
test_athena_save_data_as_csv(): @@ -122,7 +123,7 @@ def test_athena_save_data_as_csv(): cfg.schema = 'demo' cfg.view = 'student_open_data' - extractor = AthenaStreamingDataExtractor( + extractor = AthenaDataExtractor( configuration=cfg, metadata=metadata, dataset_specification=dataset @@ -140,6 +141,8 @@ def test_athena_save_data_as_csv(): for column in dataset.dimensions: assert column in df.columns assert 'Number' in df.columns + for dimension in dataset.dimensions: + assert dimension in df.columns assert len(df.columns) == len(dataset.items) @@ -155,7 +158,7 @@ def test_athena_validate(): cfg.schema = 'demo' cfg.view = 'student_open_data' - extractor = AthenaStreamingDataExtractor( + extractor = AthenaDataExtractor( configuration=cfg, metadata=metadata, dataset_specification=dataset @@ -171,4 +174,37 @@ def test_athena_validate(): validator.validate_data(allow_nulls=False) +def test_athena_save_data_as_csv_with_gzip(): + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") + + shutil.rmtree('output/test_athena', ignore_errors=True) + os.makedirs('output/test_athena', exist_ok=True) + + dataset, metadata, cfg = get_test_conf() + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + + extractor.save_data_as_csv( + file_path='output/test_athena/test.csv', + minimise=False, + compress_using_gzip=True, + do_not_modify_source=True + ) + + # Load and test + df = pd.read_csv('output/test_athena/test.csv.gz') + for column in dataset.dimensions: + assert column in df.columns + assert 'Number' in df.columns + for dimension in dataset.dimensions: + assert dimension in df.columns + assert len(df.columns) == len(dataset.items) From 33920f5740947dd3dbdce4c91aaeeac314ccd53e Mon Sep 17 00:00:00 2001 From: scottwilson Date: Mon, 23 Mar 2026 09:43:33 
+0000 Subject: [PATCH 13/24] feat: cache uniques for current item under validation as otherwise we tend to call it twice (once for nulls, once for domains) --- mario/validation.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mario/validation.py b/mario/validation.py index 1b35ecb..02e6fbe 100644 --- a/mario/validation.py +++ b/mario/validation.py @@ -425,6 +425,7 @@ def __init__(self, self.schema = configuration.schema self.view = configuration.view self.mapping = FieldMapping(query_format=configuration.query_format, items=dataset_specification.items) + self.cached_column_values = {} def __get_column_name__(self, item: Item): """ Returns the column name for a metadata item""" @@ -484,9 +485,22 @@ def __get_minimum_maximum_values__(self, item:Item): def __get_column_values__(self, item: Item): import pandas as pd column = self.__get_column_name__(item) + + # Get cached values if present + if column in self.cached_column_values: + return self.cached_column_values[column] + + # If the column is not present, clear the cache - we don't want to + # cache all the columns as that uses more memory + self.cached_column_values = {} + + # Load unique values sql = f"SELECT DISTINCT \"{column}\" AS checkfield FROM {self.schema}.{self.view}" df = pd.read_sql(sql, self.connection) values = df['checkfield'].to_list() + + # Cache values so we don't repeat this query + self.cached_column_values[column] = values return values def __get_data_for_hierarchy__(self, name): From d1539286bc68888d7eb5cd9f9fb0c0308edf1a70 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Mon, 23 Mar 2026 09:43:56 +0000 Subject: [PATCH 14/24] feat: split out running query from saving a CSV locally --- mario/athena.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/mario/athena.py b/mario/athena.py index cac9376..3bfa075 100644 --- a/mario/athena.py +++ b/mario/athena.py @@ -35,7 +35,7 @@ def __init__(self, configuration: 
AthenaConfiguration, dataset_specification, me def get_connection(self): """ - PyAthena provides a DBAPI connection compatible with pandas.read_sql() including chunksize. + PyAthena provides a DBAPI connection compatible with pandas.read_sql(). """ cfg = self.configuration @@ -51,7 +51,7 @@ def get_connection(self): work_group=cfg.aws_athena_workgroup, schema_name=cfg.schema, catalog_name=cfg.catalog - ) + ) def get_total(self, measure=None): """ @@ -76,21 +76,15 @@ def get_total(self, measure=None): totals_df = read_sql(totals_query[0], self.get_connection(), params=totals_query[1]) return totals_df.iat[0, 0] - def save_data_as_csv(self, file_path: str, **kwargs): + def run_query(self) -> str: """ - Athena doesn't support streaming, but natively saves CSV files - in S3 as output so we really don't need to do anything else - other than run the query and download the results from S3 - :param file_path: - :param kwargs: - :return: + Runs the SQL query in Athena and returns the + result path in S3 + :return: the S3 path where the result is stored """ import awswrangler as wr import boto3 - # Parse options - options = CsvOptions(**kwargs) - # Build SQL self.__build_query__() sql = self._query[0] @@ -110,10 +104,29 @@ def save_data_as_csv(self, file_path: str, **kwargs): meta = client.get_query_execution(QueryExecutionId=qid) s3_uri = meta["QueryExecution"]["ResultConfiguration"]["OutputLocation"] - # 3. 
Download raw Athena CSV + return s3_uri + + def save_data_as_csv(self, file_path: str, **kwargs): + """ + Athena doesn't support streaming, but natively saves CSV files + in S3 as output, so we really don't need to do anything else + other than run the query and download the results from S3 + :param file_path: + :param kwargs: + :return: + """ + import awswrangler as wr + + # Parse options + options = CsvOptions(**kwargs) + + # Run query and get output location + s3_uri = self.run_query() + + # Download raw Athena CSV wr.s3.download(path=s3_uri, local_file=file_path) - # 4. Rewrite header with FieldMapping + # Rewrite header with FieldMapping rewrite_csv_header_with_fieldmapping(file_path, self.mapping) if options.compress_using_gzip: From 86ec067127d15902cfabf80c3a47fe19ed1500c5 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Mon, 23 Mar 2026 09:44:14 +0000 Subject: [PATCH 15/24] build: ensure we have boto3 when using Athena --- requirements.txt | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 28f7792..05728c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ pypika sqlalchemy openpyxl pyathena -awswrangler \ No newline at end of file +awswrangler +boto3 \ No newline at end of file diff --git a/setup.py b/setup.py index 6992ccf..a61ce22 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,6 @@ extras_require={ 'Airflow': ['apache-airflow-providers-common-sql'], 'Tableau': ['pantab', 'tableauhyperapi', 'tableau-builder==0.18'], - 'Athena': ['pyathena', 'awswrangler'] + 'Athena': ['pyathena', 'awswrangler', 'boto3'] } ) \ No newline at end of file From 29ee5f3e50e99171f975315a12345633ff8c6b6f Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 25 Mar 2026 09:21:38 +0000 Subject: [PATCH 16/24] fix: map column name to physical name when processing constraints --- mario/query_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/mario/query_builder.py b/mario/query_builder.py index cac2a55..6836665 100644 --- a/mario/query_builder.py +++ b/mario/query_builder.py @@ -165,7 +165,7 @@ def create_constraints(self, q): clauses = [] parameters = {} for constraint in self.dataset_specification.constraints: - column = constraint.item + column = self.mapping.as_physical[constraint.item] placeholders = [] for i in range(len(constraint.allowed_values)): parameter_name = column.replace(" ", "_") + str(i) From c2208f69f33b2e1c3fb5e17c21642f024f65ddaf Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 25 Mar 2026 09:22:11 +0000 Subject: [PATCH 17/24] fix: interpolate parameters in query --- mario/athena.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/mario/athena.py b/mario/athena.py index 3bfa075..b70ec36 100644 --- a/mario/athena.py +++ b/mario/athena.py @@ -1,6 +1,9 @@ from pyathena import connect from mario.data_extractor import DataExtractor, Configuration import logging +import datetime +import numbers +import re from mario.mapping import rewrite_csv_header_with_fieldmapping from mario.options import CsvOptions @@ -58,7 +61,7 @@ def get_total(self, measure=None): For totals when streaming data we need to run a totals SQL query separate from the main query and use the results of this :return: the total value of the query - TODO this is a direct copy from StreamingDataExtractor + NOTE this is a direct copy from StreamingDataExtractor """ from pandas import read_sql logger.info("Building totals query") @@ -87,7 +90,8 @@ def run_query(self) -> str: # Build SQL self.__build_query__() - sql = self._query[0] + raw_sql, parameters = self._query + sql = interpolate_athena_sql(raw_sql, parameters) cfg = self.configuration # 1. 
Run SQL via Wrangler @@ -134,3 +138,56 @@ def save_data_as_csv(self, file_path: str, **kwargs): return gz_path return file_path + + +ATHENA_PARAM_PATTERN = re.compile(r'%\((?P<name>[a-zA-Z0-9_]+)\)s') + + +def athena_escape_string(value: str) -> str: + """ + Escape a Python string for safe Athena SQL literal usage. + Athena uses standard SQL single-quoted strings; internal quotes doubled. + """ + return value.replace("'", "''") + + +def athena_literal(value): + """ + Convert a Python value into an Athena-safe SQL literal. + """ + # None -> NULL + if value is None: + return "NULL" + + # Boolean -> true/false + if isinstance(value, bool): + return "true" if value else "false" + + # Number -> raw + if isinstance(value, numbers.Number): + return str(value) + + # Date / datetime + if isinstance(value, (datetime.date, datetime.datetime)): + return f"'{value.isoformat()}'" + + # List/tuple -> comma-separated list of literals + if isinstance(value, (list, tuple)): + return ", ".join(athena_literal(v) for v in value) + + # Everything else -> string literal + return f"'{athena_escape_string(str(value))}'" + + +def interpolate_athena_sql(sql: str, params: dict): + """ + Replace all %(name)s placeholders using Athena-safe literals. 
+ """ + + def repl(match): + name = match.group("name") + if name not in params: + raise KeyError(f"Missing SQL parameter: {name}") + return athena_literal(params[name]) + + return ATHENA_PARAM_PATTERN.sub(repl, sql) \ No newline at end of file From 36d8c7535eca29502ff05d14a3e65d4f535831f8 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 25 Mar 2026 09:22:35 +0000 Subject: [PATCH 18/24] test: added unit test for parameterised athena query --- test/test_athena_extractor.py | 80 +++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index a531163..70d44c8 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -3,7 +3,7 @@ import pytest from mario.athena import AthenaConfiguration, AthenaDataExtractor -from mario.dataset_specification import DatasetSpecification +from mario.dataset_specification import DatasetSpecification, Constraint from mario.query_builder import SubsetQueryBuilder from mario.metadata import Metadata, Item import os @@ -57,41 +57,6 @@ def get_test_conf(): return dataset, metadata, config -# def test_athena_stream_sql_to_csv(): -# # Skip this test if we don't have AWS env vars -# if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: -# pytest.skip("Skipping Athena test as AWS not configured") -# -# shutil.rmtree('output/test_athena', ignore_errors=True) -# os.makedirs('output/test_athena', exist_ok=True) -# -# dataset, metadata, cfg = get_test_conf() -# cfg.query_builder = SubsetQueryBuilder -# cfg.schema = 'demo' -# cfg.view = 'student_open_data' -# -# extractor = AthenaDataExtractor( -# configuration=cfg, -# metadata=metadata, -# dataset_specification=dataset -# ) -# -# extractor.stream_sql_to_csv( -# file_path='output/test_athena/test.csv', -# minimise=True, -# compress_using_gzip=False, -# do_not_modify_source=True -# ) -# -# # Load and test -# df = pd.read_csv('output/test_athena/test.csv') -# 
for column in dataset.dimensions: -# assert column in df.columns -# assert 'Number' in df.columns -# assert len(df.columns) == len(dataset.items) -# assert df['Number'].sum() == 28_733_910 - - def test_athena_count(): if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: pytest.skip("Skipping Athena test as AWS not configured") @@ -208,3 +173,46 @@ def test_athena_save_data_as_csv_with_gzip(): assert dimension in df.columns assert len(df.columns) == len(dataset.items) + +def test_athena_with_constraints(): + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") + + shutil.rmtree('output/test_athena_with_constraints', ignore_errors=True) + os.makedirs('output/test_athena_with_constraints', exist_ok=True) + + dataset, metadata, cfg = get_test_conf() + level_of_study = Item() + level_of_study.name = 'Level of study' + level_constraint = Constraint() + level_constraint.item = level_of_study.name + level_constraint.allowed_values = ['Postgraduate (research)', 'Postgraduate (taught)'] + metadata.add_item(level_of_study) + dataset.dimensions.append(level_of_study.name) + dataset.constraints.append(level_constraint) + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + + extractor.save_data_as_csv( + file_path='output/test_athena_with_constraints/test.csv', + minimise=False, + compress_using_gzip=False, + do_not_modify_source=True + ) + + # Load and test + df = pd.read_csv('output/test_athena_with_constraints/test.csv') + for column in dataset.dimensions: + assert column in df.columns + assert 'Number' in df.columns + for dimension in dataset.dimensions: + assert dimension in df.columns + assert len(df.columns) == len(dataset.items) + print(df['Level of study'].unique()) From b86524289ddffd3b5e36396466b0edf9bcb19749 Mon Sep 17 
00:00:00 2001 From: scottwilson Date: Wed, 25 Mar 2026 09:29:51 +0000 Subject: [PATCH 19/24] fix: ensure constraints are also appended to the FieldMapping list even when not present as a dimension --- mario/query_builder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mario/query_builder.py b/mario/query_builder.py index 6836665..1dd3c03 100644 --- a/mario/query_builder.py +++ b/mario/query_builder.py @@ -32,7 +32,11 @@ def __init__(self, self.configuration = configuration self.metadata = metadata self.dataset_specification = dataset_specification - self.mapping = FieldMapping(query_format=configuration.query_format, items=dataset_specification.items) + items = dataset_specification.items + if dataset_specification.constraints: + for constraint in dataset_specification.constraints: + items.append(constraint.item) + self.mapping = FieldMapping(query_format=configuration.query_format, items=items) def create_query(self) -> [str, List[any]]: raise NotImplementedError From 3cb7d818973d8ff777db6e410ab364f245959eed Mon Sep 17 00:00:00 2001 From: scottwilson Date: Wed, 25 Mar 2026 15:46:32 +0000 Subject: [PATCH 20/24] fix: join paths using string values only when splitting files. 
--- mario/dataset_splitter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mario/dataset_splitter.py b/mario/dataset_splitter.py index 29c8b9e..7ff0639 100644 --- a/mario/dataset_splitter.py +++ b/mario/dataset_splitter.py @@ -252,7 +252,7 @@ def split_excel_using_builder(self, file_name: str) -> None: # For each unique value, subset the data and create the output for value in values: logger.info(f"Splitting {file_name} for value {value}") - output_folder = os.path.join(self.output_path, value) + output_folder = os.path.join(self.output_path, str(value)) os.makedirs(output_folder, exist_ok=True) output_path = os.path.join(output_folder, file_name) builder.filepath = output_path From 23e8c060eded2c4283bb22f6210f3044bb76c597 Mon Sep 17 00:00:00 2001 From: scottwilson Date: Thu, 26 Mar 2026 10:59:50 +0000 Subject: [PATCH 21/24] fix: Added test to ensure we handle escaping properly --- test/test_athena_extractor.py | 47 ++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/test/test_athena_extractor.py b/test/test_athena_extractor.py index 70d44c8..248903b 100644 --- a/test/test_athena_extractor.py +++ b/test/test_athena_extractor.py @@ -215,4 +215,49 @@ def test_athena_with_constraints(): for dimension in dataset.dimensions: assert dimension in df.columns assert len(df.columns) == len(dataset.items) - print(df['Level of study'].unique()) + assert len(df['Level of study'].unique()) == 2 + + +def test_athena_with_constraints_and_apostrophes(): + if not AWS_PROFILE or not AWS_ATHENA_RESULTS_DIR or not AWS_REGION: + pytest.skip("Skipping Athena test as AWS not configured") + + shutil.rmtree('output/test_athena_with_constraints_and_apostrophes', ignore_errors=True) + os.makedirs('output/test_athena_with_constraints_and_apostrophes', exist_ok=True) + + dataset, metadata, cfg = get_test_conf() + provider = Item() + provider.name = 'HE Provider' + provider_constraint = Constraint() + provider_constraint.item = 
provider.name + provider_constraint.allowed_values = ["Queen's University Belfast", "City St George's, University of London"] + metadata.add_item(provider) + dataset.dimensions.append(provider.name) + dataset.constraints.append(provider_constraint) + cfg.query_builder = SubsetQueryBuilder + cfg.schema = 'demo' + cfg.view = 'student_open_data' + + extractor = AthenaDataExtractor( + configuration=cfg, + metadata=metadata, + dataset_specification=dataset + ) + + extractor.save_data_as_csv( + file_path='output/test_athena_with_constraints_and_apostrophes/test.csv', + minimise=False, + compress_using_gzip=False, + do_not_modify_source=True + ) + + # Load and test + df = pd.read_csv('output/test_athena_with_constraints_and_apostrophes/test.csv') + for column in dataset.dimensions: + assert column in df.columns + assert 'Number' in df.columns + for dimension in dataset.dimensions: + assert dimension in df.columns + assert len(df.columns) == len(dataset.items) + assert len(df['HE Provider'].unique()) == 2 + assert "Queen's University Belfast" in df['HE Provider'].unique() \ No newline at end of file From ccad381bd239c8a819511ad4b541d165eb766986 Mon Sep 17 00:00:00 2001 From: "tiffany.cheng" Date: Wed, 6 May 2026 13:24:03 +0100 Subject: [PATCH 22/24] Add encoding to pass the pytest --- mario/dataset_splitter.py | 6 +++--- test/test_dataset_splitter.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mario/dataset_splitter.py b/mario/dataset_splitter.py index 7ff0639..001171a 100644 --- a/mario/dataset_splitter.py +++ b/mario/dataset_splitter.py @@ -123,7 +123,7 @@ def process_batch(self, batch, column_name, file_handles, file_name): if value not in file_handles: os.makedirs(os.path.join(self.output_path, value), exist_ok=True) file_path = self.get_output_path(field_value=value, file=file_name) - file_handles[value] = open(file_path, 'w', newline='') + file_handles[value] = open(file_path, 'w', newline='', encoding='utf-8') writer = 
csv.DictWriter(file_handles[value], fieldnames=row.keys()) writer.writeheader() writer = csv.DictWriter(file_handles[value], fieldnames=row.keys()) @@ -142,7 +142,7 @@ def split_gzipped_csv(self, file_name: str, batch_size=10000): output_file_name = file_name.rstrip('.gz') - with gzip.open(file_path, 'rt', newline='') as infile: + with gzip.open(file_path, 'rt', newline='', encoding='utf-8') as infile: reader = csv.DictReader(infile) # Process the input file in batches @@ -172,7 +172,7 @@ def split_csv(self, file_name: str, batch_size=10000, compression=None): file_path = os.path.join(self.source_path, file_name) # Open the input CSV file for reading - with open(file_path, 'r') as infile: + with open(file_path, 'r', encoding='utf-8') as infile: reader = csv.DictReader(infile) # Process the input file in batches diff --git a/test/test_dataset_splitter.py b/test/test_dataset_splitter.py index 3a4873e..170dec0 100644 --- a/test/test_dataset_splitter.py +++ b/test/test_dataset_splitter.py @@ -9,7 +9,7 @@ def count_rows_in_csv(file_path): - with open(file_path, 'r') as file: + with open(file_path, 'r', encoding='utf-8') as file: reader = csv.reader(file) row_count = sum(1 for row in reader) - 1 # Subtract 1 to exclude the header row return row_count From 94737d330b3d3aa61332f275cc0028286020b47a Mon Sep 17 00:00:00 2001 From: "tiffany.cheng" Date: Fri, 8 May 2026 10:24:27 +0100 Subject: [PATCH 23/24] Build: release version 0.62 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 499411d..310aff7 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='mario-pipeline-tools', - version='0.61', + version='0.62', packages=['mario'], url='https://github.com/JiscDACT/mario', license='all rights reserved', From 404b713c745501784d873b2d42eda84924024a60 Mon Sep 17 00:00:00 2001 From: "tiffany.cheng" Date: Fri, 8 May 2026 11:30:42 +0100 Subject: [PATCH 24/24] Fix flake8 module import issue --- mario/data_extractor.py | 
10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mario/data_extractor.py b/mario/data_extractor.py index 369fed4..8410669 100644 --- a/mario/data_extractor.py +++ b/mario/data_extractor.py @@ -170,7 +170,7 @@ def __drop_row_numbers__(self): if self._data is None: self.__load__() if 'row_number' in self._data.columns: - self._data = self._data .drop(columns=['row_number']) + self._data = self._data.drop(columns=['row_number']) def save_query(self, file_path: str, formatted: bool = False): """ @@ -228,6 +228,7 @@ class HyperFile(DataExtractor): Wrapper for a HyperFile as a data extractor - use when no data needs to be extracted, and we just want to treat a hyper as a hyper with no conversion to/from dataframe """ + def __init__(self, configuration: Configuration, dataset_specification: DatasetSpecification, @@ -288,7 +289,7 @@ def save_data_as_hyper(self, file_path: str, **kwargs): ) save_hyper_as_hyper(hyper_file=self.configuration.file_path, file_path=file_path, **kwargs) - def save_data_as_csv(self,file_path: str, **kwargs): + def save_data_as_csv(self, file_path: str, **kwargs): from mario.hyper_utils import save_hyper_as_csv options = CsvOptions(**kwargs) if options.minimise: @@ -308,6 +309,7 @@ class StreamingDataExtractor(DataExtractor): supporting streaming data from SQL to output formats without holding any data in memory using a data frame """ + def __init__(self, configuration: Configuration, dataset_specification: DatasetSpecification, @@ -410,6 +412,7 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs): frame_to_hyper(df, database=file_path, table=table_name, table_mode='a') def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, **kwargs) -> int: + from mario.query_builder import get_formatted_query options = CsvOptions(**kwargs) if options.compress_using_gzip: compression_options = dict(method='gzip') @@ -545,6 +548,7 @@ class PartitioningExtractor(StreamingDataExtractor): A data extractor that 
loads from SQL in batches using a specified constraint to partition by """ + def __init__(self, configuration: Configuration, dataset_specification: DatasetSpecification, @@ -671,5 +675,3 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs): else: logger.info(f"Saving {options.chunk_size} rows to file") frame_to_hyper(df, database=file_path, table=table_name, table_mode='a') - -