Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
88e210b
feat: Added a FieldMapping class to handle transformations for physic…
scottbw Mar 9, 2026
c4fbc46
feat: Added FieldMapping to Data Extractor and Query Builder
scottbw Mar 9, 2026
6d476a7
feat: Added Athena data extractor
scottbw Mar 9, 2026
c8b0a35
feat: reverse physical->logical mapping when writing data
scottbw Mar 11, 2026
754a116
test: added testing for column spec
scottbw Mar 11, 2026
6f3385c
test: added testing for totals
scottbw Mar 12, 2026
0c1d143
fix: dropped some of the defaults and skip tests if we don't have acc…
scottbw Mar 12, 2026
00202d0
fix: added FieldMapping to SqlValidator, and use hooks where availabl…
scottbw Mar 18, 2026
55743b3
test: Added validation test
scottbw Mar 18, 2026
04c4abc
doc: added link to hook docs for Airflow+Athena
scottbw Mar 18, 2026
98aebd7
test: skip Athena tests if any env vars aren't set
scottbw Mar 19, 2026
9e211a1
feat: refactored Athena as a bulk rather than streaming extractor, as…
scottbw Mar 20, 2026
33920f5
feat: cache uniques for current item under validation as otherwise we…
scottbw Mar 23, 2026
d153928
feat: split out running query from saving a CSV locally
scottbw Mar 23, 2026
86ec067
build: ensure we have boto3 when using Athena
scottbw Mar 23, 2026
29ee5f3
fix: map column name to physical name when processing constraints
scottbw Mar 25, 2026
c2208f6
fix: interpolate parameters in query
scottbw Mar 25, 2026
36d8c75
test: added unit test for parameterised athena query
scottbw Mar 25, 2026
b865242
fix: ensure constraints are also appended to the FieldMapping list ev…
scottbw Mar 25, 2026
3cb7d81
fix: join paths using string values only when splitting files.
scottbw Mar 25, 2026
23e8c06
fix: Added test to ensure we handle escaping properly
scottbw Mar 26, 2026
7f822be
Merge pull request #11 from JiscDACT/main
scottbw Mar 26, 2026
ccad381
Add encoding to pass the pytest
TiffanyCheng27 May 6, 2026
94737d3
Build: release version 0.62
TiffanyCheng27 May 8, 2026
8a909db
Merge branch 'main' into athena
TiffanyCheng27 May 8, 2026
404b713
Fix flake8 module import issue
TiffanyCheng27 May 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions mario/athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
from pyathena import connect
from mario.data_extractor import DataExtractor, Configuration
import logging
import datetime
import numbers
import re

from mario.mapping import rewrite_csv_header_with_fieldmapping
from mario.options import CsvOptions
from mario.utils import gzip_file

logger = logging.getLogger(__name__)


class AthenaConfiguration(Configuration):
"""
Extended configuration
"""

def __init__(self):
super().__init__()
self.aws_s3_staging_dir = None
self.aws_region_name = None
self.aws_athena_workgroup = 'primary'
self.catalog = 'awsdatacatalog'
self.query_format = 'snake_case'


class AthenaDataExtractor(DataExtractor):
"""
Streaming extractor using Athena + PyAthena.
Extends StreamingDataExtractor; only get_connection() is Athena-specific.
"""

def __init__(self, configuration: AthenaConfiguration, dataset_specification, metadata):
super().__init__(configuration, dataset_specification, metadata)
self.configuration = configuration

def get_connection(self):
"""
PyAthena provides a DBAPI connection compatible with pandas.read_sql().
"""
cfg = self.configuration

if cfg.hook:
# If the user provided a hook, delegate to it
# See: https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/_api/airflow/providers/amazon/aws/hooks/athena_sql/index.html
return cfg.hook.get_conn()

# Expect these in configuration:
return connect(
s3_staging_dir=cfg.aws_s3_staging_dir,
region_name=cfg.aws_region_name,
work_group=cfg.aws_athena_workgroup,
schema_name=cfg.schema,
catalog_name=cfg.catalog
)

def get_total(self, measure=None):
"""
For totals when streaming data we need to run a totals SQL query separate
from the main query and use the results of this
:return: the total value of the query
NOTE this is a direct copy from StreamingDataExtractor
"""
from pandas import read_sql
logger.info("Building totals query")
measure = self.__get_measure__(measure)
if self.configuration.query_builder is not None:
from mario.query_builder import QueryBuilder
query_builder: QueryBuilder = self.configuration.query_builder(
configuration=self.configuration,
metadata=self.metadata,
dataset_specification=self.dataset_specification)
totals_query = query_builder.create_totals_query(measure=measure)
else:
raise NotImplementedError

totals_df = read_sql(totals_query[0], self.get_connection(), params=totals_query[1])
return totals_df.iat[0, 0]

def run_query(self) -> str:
"""
Runs the SQL query in Athena and returns the
result path in S3
:return: the S3 path where the result is stored
"""
import awswrangler as wr
import boto3

# Build SQL
self.__build_query__()
raw_sql, parameters = self._query
sql = interpolate_athena_sql(raw_sql, parameters)
cfg = self.configuration

# 1. Run SQL via Wrangler
qid = wr.athena.start_query_execution(
sql=sql,
database=cfg.schema,
workgroup=cfg.aws_athena_workgroup,
s3_output=cfg.aws_s3_staging_dir,
)
wr.athena.wait_query(qid)

# 2. Get S3 CSV result path
client = boto3.client("athena")
meta = client.get_query_execution(QueryExecutionId=qid)
s3_uri = meta["QueryExecution"]["ResultConfiguration"]["OutputLocation"]

return s3_uri

def save_data_as_csv(self, file_path: str, **kwargs):
"""
Athena doesn't support streaming, but natively saves CSV files
in S3 as output, so we really don't need to do anything else
other than run the query and download the results from S3
:param file_path:
:param kwargs:
:return:
"""
import awswrangler as wr

# Parse options
options = CsvOptions(**kwargs)

# Run query and get output location
s3_uri = self.run_query()

# Download raw Athena CSV
wr.s3.download(path=s3_uri, local_file=file_path)

# Rewrite header with FieldMapping
rewrite_csv_header_with_fieldmapping(file_path, self.mapping)

if options.compress_using_gzip:
gz_path = gzip_file(file_path)
return gz_path

return file_path


ATHENA_PARAM_PATTERN = re.compile(r'%\((?P<name>[a-zA-Z0-9_]+)\)s')


def athena_escape_string(value: str) -> str:
"""
Escape a Python string for safe Athena SQL literal usage.
Athena uses standard SQL single-quoted strings; internal quotes doubled.
"""
return value.replace("'", "''")


def athena_literal(value):
"""
Convert a Python value into an Athena-safe SQL literal.
"""
# None -> NULL
if value is None:
return "NULL"

# Boolean -> true/false
if isinstance(value, bool):
return "true" if value else "false"

# Number -> raw
if isinstance(value, numbers.Number):
return str(value)

# Date / datetime
if isinstance(value, (datetime.date, datetime.datetime)):
return f"'{value.isoformat()}'"

# List/tuple -> comma-separated list of literals
if isinstance(value, (list, tuple)):
return ", ".join(athena_literal(v) for v in value)

# Everything else -> string literal
return f"'{athena_escape_string(str(value))}'"


def interpolate_athena_sql(sql: str, params: dict):
"""
Replace all %(name)s placeholders using Athena-safe literals.
"""

def repl(match):
name = match.group("name")
if name not in params:
raise KeyError(f"Missing SQL parameter: {name}")
return athena_literal(params[name])

return ATHENA_PARAM_PATTERN.sub(repl, sql)
39 changes: 21 additions & 18 deletions mario/data_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mario.dataset_specification import DatasetSpecification
from mario.metadata import Metadata
from mario.options import CsvOptions, HyperOptions
from mario.mapping import FieldMapping

logger = logging.getLogger(__name__)

Expand All @@ -28,7 +29,8 @@ def __init__(self,
schema: str = None,
file_path: str = None,
query_builder=None,
user: str = None
user: str = None,
query_format=None
):
self.connection_string = connection_string
self.hook = hook
Expand All @@ -37,6 +39,7 @@ def __init__(self,
self.file_path = file_path
self.query_builder = query_builder
self.user = user
self.query_format = query_format


class DataExtractor:
Expand All @@ -51,6 +54,10 @@ def __init__(self,
self._data = None
self._query = None
self._total = 0
self.mapping = FieldMapping(
query_format=configuration.query_format,
items=dataset_specification.items
)

def __load__(self):
if self.configuration is not None:
Expand All @@ -76,7 +83,9 @@ def __load_from_sql__(self):
logger.info("Executing query")
from sqlalchemy import create_engine
engine = create_engine(self.configuration.connection_string)
self._data = pd.read_sql(sql=self._query[0], con=engine.connect(), params=self._query[1])
df = pd.read_sql(sql=self._query[0], con=engine.connect(), params=self._query[1])
df = self.mapping.df_to_logical(df)
self._data = df

def __build_query__(self):
logger.info("Building query")
Expand Down Expand Up @@ -105,22 +114,12 @@ def __load_from_hyper__(self):
table=table
)

def __get_column_name__(self, item: str):
""" Returns the column name for a metadata item"""
meta = self.metadata.get_metadata(item)
if meta.get_property('output_name') is not None:
return meta.get_property('output_name')
elif meta.get_property('physical_column_name') is not None:
return meta.get_property('physical_column_name')
else:
return meta.name

def __minimise_data__(self):
""" Minimise data so we only keep the columns in the spec """
columns_to_keep = []
for item in self.dataset_specification.items:
if self.metadata.get_metadata(item) and not self.metadata.get_metadata(item).get_property('formula'):
columns_to_keep.append(self.__get_column_name__(item))
columns_to_keep.append(item)
self._data = self._data[columns_to_keep]

def __get_measure__(self, measure=None):
Expand Down Expand Up @@ -171,7 +170,7 @@ def __drop_row_numbers__(self):
if self._data is None:
self.__load__()
if 'row_number' in self._data.columns:
self._data = self._data .drop(columns=['row_number'])
self._data = self._data.drop(columns=['row_number'])

def save_query(self, file_path: str, formatted: bool = False):
"""
Expand Down Expand Up @@ -229,6 +228,7 @@ class HyperFile(DataExtractor):
Wrapper for a HyperFile as a data extractor - use when no data needs to be extracted,
and we just want to treat a hyper as a hyper with no conversion to/from dataframe
"""

def __init__(self,
configuration: Configuration,
dataset_specification: DatasetSpecification,
Expand Down Expand Up @@ -289,7 +289,7 @@ def save_data_as_hyper(self, file_path: str, **kwargs):
)
save_hyper_as_hyper(hyper_file=self.configuration.file_path, file_path=file_path, **kwargs)

def save_data_as_csv(self,file_path: str, **kwargs):
def save_data_as_csv(self, file_path: str, **kwargs):
from mario.hyper_utils import save_hyper_as_csv
options = CsvOptions(**kwargs)
if options.minimise:
Expand All @@ -309,6 +309,7 @@ class StreamingDataExtractor(DataExtractor):
supporting streaming data from SQL to output formats without
holding any data in memory using a data frame
"""

def __init__(self,
configuration: Configuration,
dataset_specification: DatasetSpecification,
Expand Down Expand Up @@ -396,6 +397,7 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs):
table_name = TableName(options.schema, options.table)
row_counter = 0
for df in pd.read_sql(self._query[0], connection, chunksize=options.chunk_size):
df = self.mapping.df_to_logical(df)
if options.validate or options.minimise or options.include_row_numbers:
self._data = df
if options.validate:
Expand All @@ -410,6 +412,7 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs):
frame_to_hyper(df, database=file_path, table=table_name, table_mode='a')

def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, **kwargs) -> int:
from mario.query_builder import get_formatted_query
options = CsvOptions(**kwargs)
if options.compress_using_gzip:
compression_options = dict(method='gzip')
Expand All @@ -420,7 +423,8 @@ def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, *
mode = 'w'
header = True

for df in pd.read_sql(sql=query[0], params=query[1], con=connection, chunksize=options.chunk_size):
for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size):
df = self.mapping.df_to_logical(df)
if options.validate or options.minimise:
self._data = df
if options.validate:
Expand Down Expand Up @@ -544,6 +548,7 @@ class PartitioningExtractor(StreamingDataExtractor):
A data extractor that loads from SQL in batches using a specified constraint
to partition by
"""

def __init__(self,
configuration: Configuration,
dataset_specification: DatasetSpecification,
Expand Down Expand Up @@ -670,5 +675,3 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs):
else:
logger.info(f"Saving {options.chunk_size} rows to file")
frame_to_hyper(df, database=file_path, table=table_name, table_mode='a')


6 changes: 3 additions & 3 deletions mario/dataset_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def process_batch(self, batch, column_name, file_handles, file_name):
if value not in file_handles:
os.makedirs(os.path.join(self.output_path, value), exist_ok=True)
file_path = self.get_output_path(field_value=value, file=file_name)
file_handles[value] = open(file_path, 'w', newline='')
file_handles[value] = open(file_path, 'w', newline='', encoding='utf-8')
writer = csv.DictWriter(file_handles[value], fieldnames=row.keys())
writer.writeheader()
writer = csv.DictWriter(file_handles[value], fieldnames=row.keys())
Expand All @@ -142,7 +142,7 @@ def split_gzipped_csv(self, file_name: str, batch_size=10000):

output_file_name = file_name.rstrip('.gz')

with gzip.open(file_path, 'rt', newline='') as infile:
with gzip.open(file_path, 'rt', newline='', encoding='utf-8') as infile:
reader = csv.DictReader(infile)

# Process the input file in batches
Expand Down Expand Up @@ -172,7 +172,7 @@ def split_csv(self, file_name: str, batch_size=10000, compression=None):
file_path = os.path.join(self.source_path, file_name)

# Open the input CSV file for reading
with open(file_path, 'r') as infile:
with open(file_path, 'r', encoding='utf-8') as infile:
reader = csv.DictReader(infile)

# Process the input file in batches
Expand Down
Loading
Loading