diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ebd629..7da6c8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## v0.5.2 (2026-02-02) + +### Refactor + +- allow passing of custom loggers into pipeline objects +- ensure traceback in broad exceptions +- improve the logging around dve processing errors and align reporting to module name rather than legacy name +- add sense check for text based file (#32) + +## v0.5.1 (2026-01-28) + +### Fix + +- deal with pathing assumption that file had been moved to processed_file_path during file transformation + ## v0.5.0 (2026-01-16) ### Feat diff --git a/Makefile b/Makefile index cfad520..7684514 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ behave: ${activate} behave pytest: - ${activate} pytest tests/ + ${activate} pytest -c pytest-dev.ini all-tests: pytest behave diff --git a/docs/README.md b/docs/README.md index 7ab5d92..fc0de4a 100644 --- a/docs/README.md +++ b/docs/README.md @@ -165,8 +165,8 @@ for entity in data_contract_config.schemas: # Data contract step here data_contract = SparkDataContract(spark_session=spark) -entities, validation_messages, success = data_contract.apply_data_contract( - entities, data_contract_config +entities, feedback_errors_uri, success = data_contract.apply_data_contract( + entities, None, data_contract_config ) ``` diff --git a/pyproject.toml b/pyproject.toml index 52263a9..ca0c98f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nhs_dve" -version = "0.5.0" +version = "0.5.2" description = "`nhs data validation engine` is a framework used to validate data" authors = ["NHS England "] readme = "README.md" diff --git a/pytest-dev.ini b/pytest-dev.ini new file mode 100644 index 0000000..11c72fa --- /dev/null +++ b/pytest-dev.ini @@ -0,0 +1,3 @@ +[pytest] +log_cli = true +log_cli_level = INFO diff --git a/tests/test_error_reporting/__init__.py b/src/dve/common/__init__.py similarity index 100% rename from tests/test_error_reporting/__init__.py rename to src/dve/common/__init__.py diff --git a/src/dve/common/error_utils.py b/src/dve/common/error_utils.py new file mode 100644 index 0000000..8dcd465 --- /dev/null +++ b/src/dve/common/error_utils.py @@ -0,0 +1,187 @@ +"""Utilities to support reporting""" + +import datetime as dt +import json +import logging +from collections.abc import Iterable +from itertools import chain +from multiprocessing import Queue +from threading import Thread +from typing import Optional, Union + +import dve.parser.file_handling as fh +from dve.core_engine.exceptions import CriticalProcessingError +from dve.core_engine.loggers import get_logger +from dve.core_engine.message import UserMessage +from dve.core_engine.type_hints import URI, DVEStage, Messages + + +def get_feedback_errors_uri(working_folder: URI, step_name: DVEStage) -> URI: + """Determine the location of json lines file containing all errors generated in a step""" + return fh.joinuri(working_folder, "errors", f"{step_name}_errors.jsonl") + + +def get_processing_errors_uri(working_folder: URI) -> URI: + """Determine the location of json lines file containing all processing + errors generated from DVE run""" + return fh.joinuri(working_folder, "errors", "processing_errors", "processing_errors.jsonl") + + +def dump_feedback_errors( + working_folder: URI, + step_name: DVEStage, + messages: Messages, + key_fields: Optional[dict[str, list[str]]] = None, +) -> URI: + """Write out captured feedback error messages.""" + if not working_folder: + raise AttributeError("processed files 
path not passed") + + if not key_fields: + key_fields = {} + + error_file = get_feedback_errors_uri(working_folder, step_name) + processed = [] + + for message in messages: + if message.original_entity is not None: + primary_keys = key_fields.get(message.original_entity, []) + elif message.entity is not None: + primary_keys = key_fields.get(message.entity, []) + else: + primary_keys = [] + + error = message.to_dict( + key_field=primary_keys, + value_separator=" -- ", + max_number_of_values=10, + record_converter=None, + ) + error["Key"] = conditional_cast(error["Key"], primary_keys, value_separator=" -- ") + processed.append(error) + + with fh.open_stream(error_file, "a") as f: + f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n") + return error_file + + +def dump_processing_errors( + working_folder: URI, step_name: str, errors: list[CriticalProcessingError] +): + """Write out critical processing errors""" + if not working_folder: + raise AttributeError("processed files path not passed") + if not step_name: + raise AttributeError("step name not passed") + if not errors: + raise AttributeError("errors list not passed") + + error_file: URI = fh.joinuri(working_folder, "processing_errors", "processing_errors.json") + processed = [] + + for error in errors: + processed.append( + { + "step_name": step_name, + "error_location": "processing", + "error_level": "integrity", + "error_message": error.error_message, + "error_traceback": error.messages, + } + ) + + with fh.open_stream(error_file, "a") as f: + f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n") + + return error_file + + +def load_feedback_messages(feedback_messages_uri: URI) -> Iterable[UserMessage]: + """Load user messages from jsonl file""" + if not fh.get_resource_exists(feedback_messages_uri): + return + with fh.open_stream(feedback_messages_uri) as errs: + yield from (UserMessage(**json.loads(err)) for err in errs.readlines()) + + +def load_all_error_messages(error_directory_uri: URI) -> Iterable[UserMessage]: + "Load user messages from all jsonl files" + return chain.from_iterable( + [ + load_feedback_messages(err_file) + for err_file, _ in fh.iter_prefix(error_directory_uri) + if err_file.endswith(".jsonl") + ] + ) + + +class BackgroundMessageWriter: + """Controls batch writes to error jsonl files""" + + def __init__( + self, + working_directory: URI, + dve_stage: DVEStage, + key_fields: Optional[dict[str, list[str]]] = None, + logger: Optional[logging.Logger] = None, + ): + self._working_directory = working_directory + self._dve_stage = dve_stage + self._feedback_message_uri = get_feedback_errors_uri( + self._working_directory, self._dve_stage + ) + self._key_fields = key_fields + self.logger = logger or get_logger(type(self).__name__) + self._write_thread: Optional[Thread] = None + self._queue: Queue = Queue() + + @property + def write_queue(self) -> Queue: # type: ignore + """Queue for storing batches of messages to be written""" + return self._queue + + @property + def write_thread(self) -> Thread: # type: ignore + """Thread to write batches of messages to jsonl file""" + if not self._write_thread: + self._write_thread = Thread(target=self._write_process_wrapper) + return self._write_thread + + def _write_process_wrapper(self): + """Wrapper for dump feedback errors to run in background process""" + while True: + if msgs := self.write_queue.get(): + dump_feedback_errors( + self._working_directory, self._dve_stage, msgs, self._key_fields + ) + else: + break + + def 
__enter__(self) -> "BackgroundMessageWriter": + self.write_thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + self.logger.exception( + "Issue occured during background write process:", + exc_info=(exc_type, exc_value, traceback), + ) + self.write_queue.put(None) + self.write_thread.join() + + +def conditional_cast(value, primary_keys: list[str], value_separator: str) -> Union[list[str], str]: + """Determines what to do with a value coming back from the error list""" + if isinstance(value, list): + casts = [ + conditional_cast(val, primary_keys, value_separator) for val in value + ] # type: ignore + return value_separator.join( + [f"{pk}: {id}" if pk else "" for pk, id in zip(primary_keys, casts)] + ) + if isinstance(value, dt.date): + return value.isoformat() + if isinstance(value, dict): + return "" + return str(value) diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py index bed2e17..9d6abaa 100644 --- a/src/dve/core_engine/backends/base/backend.py +++ b/src/dve/core_engine/backends/base/backend.py @@ -17,13 +17,8 @@ from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful from dve.core_engine.loggers import get_logger from dve.core_engine.models import SubmissionInfo -from dve.core_engine.type_hints import ( - URI, - EntityLocations, - EntityName, - EntityParquetLocations, - Messages, -) +from dve.core_engine.type_hints import URI, EntityLocations, EntityName, EntityParquetLocations +from dve.parser.file_handling.service import get_parent, joinuri class BaseBackend(Generic[EntityType], ABC): @@ -148,11 +143,12 @@ def convert_entities_to_spark( def apply( self, + working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[Entities, Messages, StageSuccessful]: + ) -> tuple[Entities, URI, StageSuccessful]: """Apply the data contract and the rules, returning the entities and all generated messages. @@ -160,9 +156,11 @@ def apply( reference_data = self.load_reference_data( rule_metadata.reference_data_config, submission_info ) - entities, messages, successful = self.contract.apply(entity_locations, contract_metadata) + entities, dc_feedback_errors_uri, successful, processing_errors_uri = self.contract.apply( + working_dir, entity_locations, contract_metadata + ) if not successful: - return entities, messages, successful + return entities, get_parent(processing_errors_uri), successful for entity_name, entity in entities.items(): entities[entity_name] = self.step_implementations.add_row_id(entity) @@ -170,43 +168,46 @@ def apply( # TODO: Handle entity manager creation errors. 
entity_manager = EntityManager(entities, reference_data) # TODO: Add stage success to 'apply_rules' - rule_messages = self.step_implementations.apply_rules(entity_manager, rule_metadata) - messages.extend(rule_messages) + # TODO: In case of large errors in business rules, write messages to jsonl file + # TODO: and return uri to errors + _ = self.step_implementations.apply_rules(working_dir, entity_manager, rule_metadata) for entity_name, entity in entity_manager.entities.items(): entity_manager.entities[entity_name] = self.step_implementations.drop_row_id(entity) - return entity_manager.entities, messages, True + return entity_manager.entities, get_parent(dc_feedback_errors_uri), True def process( self, + working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, - cache_prefix: URI, submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[MutableMapping[EntityName, URI], Messages]: + ) -> tuple[MutableMapping[EntityName, URI], URI]: """Apply the data contract and the rules, write the entities out to parquet and returning the entity locations and all generated messages. """ - entities, messages, successful = self.apply( - entity_locations, contract_metadata, rule_metadata, submission_info + entities, feedback_errors_uri, successful = self.apply( + working_dir, entity_locations, contract_metadata, rule_metadata, submission_info ) if successful: - parquet_locations = self.write_entities_to_parquet(entities, cache_prefix) + parquet_locations = self.write_entities_to_parquet( + entities, joinuri(working_dir, "outputs") + ) else: parquet_locations = {} - return parquet_locations, messages + return parquet_locations, get_parent(feedback_errors_uri) def process_legacy( self, + working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, - cache_prefix: URI, submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[MutableMapping[EntityName, DataFrame], Messages]: + ) -> tuple[MutableMapping[EntityName, DataFrame], URI]: """Apply the data contract and the rules, create Spark `DataFrame`s from the entities and return the Spark entities and all generated messages. 
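# --- Reviewer sketch (not part of the patch): how a caller is expected to drive the
# --- reworked BaseBackend.process after this change. The working directory now comes
# --- first and the second return value is a URI to the errors directory (jsonl files)
# --- rather than an in-memory Messages list. The backend instance and URIs below are
# --- placeholders; constructing a concrete backend is deployment-specific.
from dve.core_engine.backends.base.backend import BaseBackend
from dve.core_engine.configuration.v1 import V1EngineConfig


def run_backend_sketch(
    backend: BaseBackend, dataset_config_uri: str, entity_locations: dict, working_dir: str
):
    """Apply contract + rules, returning parquet locations and the errors-directory URI."""
    config = V1EngineConfig.load(dataset_config_uri)
    parquet_locations, errors_uri = backend.process(
        working_dir,                      # feedback errors land under <working_dir>/errors/
        entity_locations,                 # {entity_name: source file URI}
        config.get_contract_metadata(),
        config.get_rule_metadata(),
    )
    return parquet_locations, errors_uri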
@@ -221,17 +222,19 @@ def process_legacy( category=DeprecationWarning, ) - entities, messages, successful = self.apply( - entity_locations, contract_metadata, rule_metadata, submission_info + entities, errors_uri, successful = self.apply( + working_dir, entity_locations, contract_metadata, rule_metadata, submission_info ) if not successful: - return {}, messages + return {}, errors_uri if self.__entity_type__ == DataFrame: - return entities, messages # type: ignore + return entities, errors_uri # type: ignore return ( - self.convert_entities_to_spark(entities, cache_prefix, _emit_deprecation_warning=False), - messages, + self.convert_entities_to_spark( + entities, joinuri(working_dir, "outputs"), _emit_deprecation_warning=False + ), + errors_uri, ) diff --git a/src/dve/core_engine/backends/base/contract.py b/src/dve/core_engine/backends/base/contract.py index 338bd9f..89d29c0 100644 --- a/src/dve/core_engine/backends/base/contract.py +++ b/src/dve/core_engine/backends/base/contract.py @@ -9,6 +9,11 @@ from pydantic import BaseModel from typing_extensions import Protocol +from dve.common.error_utils import ( + dump_processing_errors, + get_feedback_errors_uri, + get_processing_errors_uri, +) from dve.core_engine.backends.base.core import get_entity_type from dve.core_engine.backends.base.reader import BaseFileReader from dve.core_engine.backends.exceptions import ReaderLacksEntityTypeSupport, render_error @@ -16,6 +21,7 @@ from dve.core_engine.backends.readers import get_reader from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful from dve.core_engine.backends.utilities import dedup_messages, stringify_model +from dve.core_engine.exceptions import CriticalProcessingError from dve.core_engine.loggers import get_logger from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import ( @@ -360,8 +366,13 @@ def read_raw_entities( @abstractmethod def apply_data_contract( - self, entities: Entities, contract_metadata: DataContractMetadata - ) -> tuple[Entities, Messages, StageSuccessful]: + self, + working_dir: URI, + entities: Entities, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[Entities, URI, StageSuccessful]: """Apply the data contract to the raw entities, returning the validated entities and any messages. @@ -371,21 +382,36 @@ def apply_data_contract( raise NotImplementedError() def apply( - self, entity_locations: EntityLocations, contract_metadata: DataContractMetadata - ) -> tuple[Entities, Messages, StageSuccessful]: + self, + working_dir: URI, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[Entities, URI, StageSuccessful, URI]: """Read the entities from the provided locations according to the data contract, and return the validated entities and any messages. 
""" + feedback_errors_uri = get_feedback_errors_uri(working_dir, "data_contract") + processing_errors_uri = get_processing_errors_uri(working_dir) entities, messages, successful = self.read_raw_entities(entity_locations, contract_metadata) if not successful: - return {}, messages, successful + dump_processing_errors( + working_dir, + "data_contract", + [ + CriticalProcessingError( + "Issue occurred while reading raw entities", + [msg.error_message for msg in messages], + ) + ], + ) + return {}, feedback_errors_uri, successful, processing_errors_uri try: - entities, contract_messages, successful = self.apply_data_contract( - entities, contract_metadata + entities, feedback_errors_uri, successful = self.apply_data_contract( + working_dir, entities, entity_locations, contract_metadata, key_fields ) - messages.extend(contract_messages) except Exception as err: # pylint: disable=broad-except successful = False new_messages = render_error( @@ -393,13 +419,22 @@ def apply( "data contract", self.logger, ) - messages.extend(new_messages) + dump_processing_errors( + working_dir, + "data_contract", + [ + CriticalProcessingError( + "Issue occurred while applying data_contract", + [msg.error_message for msg in new_messages], + ) + ], + ) if contract_metadata.cache_originals: for entity_name in list(entities): entities[f"Original{entity_name}"] = entities[entity_name] - return entities, messages, successful + return entities, feedback_errors_uri, successful, processing_errors_uri def read_parquet(self, path: URI, **kwargs) -> EntityType: """Method to read parquet files from stringified parquet output diff --git a/src/dve/core_engine/backends/base/reader.py b/src/dve/core_engine/backends/base/reader.py index 9862e7e..54abaa9 100644 --- a/src/dve/core_engine/backends/base/reader.py +++ b/src/dve/core_engine/backends/base/reader.py @@ -8,9 +8,11 @@ from pydantic import BaseModel from typing_extensions import Protocol -from dve.core_engine.backends.exceptions import ReaderLacksEntityTypeSupport +from dve.core_engine.backends.exceptions import MessageBearingError, ReaderLacksEntityTypeSupport from dve.core_engine.backends.types import EntityName, EntityType +from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import URI, ArbitraryFunction, WrapDecorator +from dve.parser.file_handling.service import open_stream T = TypeVar("T") ET_co = TypeVar("ET_co", covariant=True) @@ -116,6 +118,8 @@ def read_to_entity_type( if entity_name == Iterator[dict[str, Any]]: return self.read_to_py_iterator(resource, entity_name, schema) # type: ignore + self.raise_if_not_sensible_file(resource, entity_name) + try: reader_func = self.__read_methods__[entity_type] except KeyError as err: @@ -137,3 +141,36 @@ def write_parquet( """ raise NotImplementedError(f"write_parquet not implemented in {self.__class__}") + + @staticmethod + def _check_likely_text_file(resource: URI) -> bool: + """Quick sense check of file to see if it looks like text + - not 100% full proof, but hopefully enough to weed out most + non-text files""" + with open_stream(resource, "rb") as fle: + start_chunk = fle.read(4096) + # check for BOM character - utf-16 can contain NULL bytes + if start_chunk.startswith((b"\xff\xfe", b"\xfe\xff")): + return True + # if null byte in - unlikely text + if b"\x00" in start_chunk: + return False + return True + + def raise_if_not_sensible_file(self, resource: URI, entity_name: str): + """Sense check that the file is a text file. 
Raise error if doesn't + appear to be the case.""" + if not self._check_likely_text_file(resource): + raise MessageBearingError( + "The submitted file doesn't appear to be text", + messages=[ + FeedbackMessage( + entity=entity_name, + record=None, + failure_type="submission", + error_location="Whole File", + error_code="MalformedFile", + error_message="The resource doesn't seem to be a valid text file", + ) + ], + ) diff --git a/src/dve/core_engine/backends/base/rules.py b/src/dve/core_engine/backends/base/rules.py index 043f826..0b7b385 100644 --- a/src/dve/core_engine/backends/base/rules.py +++ b/src/dve/core_engine/backends/base/rules.py @@ -9,6 +9,12 @@ from typing_extensions import Literal, Protocol, get_type_hints +from dve.common.error_utils import ( + BackgroundMessageWriter, + dump_feedback_errors, + dump_processing_errors, + get_feedback_errors_uri, +) from dve.core_engine.backends.base.core import get_entity_type from dve.core_engine.backends.exceptions import render_error from dve.core_engine.backends.metadata.rules import ( @@ -37,6 +43,7 @@ TableUnion, ) from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful +from dve.core_engine.exceptions import CriticalProcessingError from dve.core_engine.loggers import get_logger from dve.core_engine.type_hints import URI, EntityName, Messages, TemplateVariables @@ -188,7 +195,7 @@ def evaluate(self, entities, *, config: AbstractStep) -> tuple[Messages, StageSu if success: success = False msg = f"Critical failure in rule {self._step_metadata_to_location(config)}" - self.logger.error(msg) + self.logger.exception(msg) self.logger.error(str(message)) return messages, success @@ -343,9 +350,14 @@ def notify(self, entities: Entities, *, config: Notification) -> Messages: """ + # pylint: disable=R0912,R0914 def apply_sync_filters( - self, entities: Entities, *filters: DeferredFilter - ) -> tuple[Messages, StageSuccessful]: + self, + working_directory: URI, + entities: Entities, + *filters: DeferredFilter, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[URI, StageSuccessful]: """Apply the synchronised filters, emitting appropriate error messages for any records which do not meet the conditions. 
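# --- Reviewer sketch (not part of the patch): the batching pattern that
# --- apply_sync_filters and the DuckDB/Spark contracts now rely on. Message batches
# --- are queued and a background thread appends them to
# --- <working_dir>/errors/business_rules_errors.jsonl; exiting the context manager
# --- signals the writer to stop and joins the thread. The entity/key-field names and
# --- the working directory below are hypothetical.
from dve.common.error_utils import BackgroundMessageWriter, get_feedback_errors_uri
from dve.core_engine.message import FeedbackMessage


def write_rule_failures_sketch(working_dir: str, failures: list[FeedbackMessage]) -> str:
    """Queue feedback messages for background writing and return the jsonl URI."""
    with BackgroundMessageWriter(
        working_directory=working_dir,
        dve_stage="business_rules",
        key_fields={"Demographics": ["nhs_number"]},  # hypothetical entity -> key fields
    ) as writer:
        # batch the puts so each queue item (and each jsonl append) stays a manageable size
        for start in range(0, len(failures), 10_000):
            writer.write_queue.put(failures[start : start + 10_000])
    # the writer thread has been joined here, so all batches are flushed to disk
    return get_feedback_errors_uri(working_dir, "business_rules")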
@@ -355,108 +367,191 @@ def apply_sync_filters( """ filters_by_entity: dict[EntityName, list[DeferredFilter]] = defaultdict(list) + feedback_errors_uri = get_feedback_errors_uri(working_directory, "business_rules") for rule in filters: filters_by_entity[rule.entity_name].append(rule) - messages: Messages = [] - for entity_name, filter_rules in filters_by_entity.items(): - entity = entities[entity_name] - - filter_column_names: list[str] = [] - unmodified_entities = {entity_name: entity} - modified_entities = {entity_name: entity} - - for rule in filter_rules: - if rule.reporting.emit == "record_failure": - column_name = f"filter_{uuid4().hex}" - filter_column_names.append(column_name) - temp_messages, success = self.evaluate( - modified_entities, - config=ColumnAddition( - entity_name=entity_name, - column_name=column_name, - expression=rule.expression, - parent=rule.parent, - ), - ) - messages.extend(temp_messages) - if not success: - return messages, False - - temp_messages, success = self.evaluate( - modified_entities, - config=Notification( - entity_name=entity_name, - expression=f"NOT {column_name}", - excluded_columns=filter_column_names, - reporting=rule.reporting, - parent=rule.parent, - ), - ) - messages.extend(temp_messages) - if not success: - return messages, False - - else: - temp_messages, success = self.evaluate( - unmodified_entities, - config=Notification( - entity_name=entity_name, - expression=f"NOT ({rule.expression})", - reporting=rule.reporting, - parent=rule.parent, - ), - ) - messages.extend(temp_messages) - if not success: - return messages, False - - if filter_column_names: - success_condition = " AND ".join( - [f"({c_name} IS NOT NULL AND {c_name})" for c_name in filter_column_names] - ) - temp_messages, success = self.evaluate( - modified_entities, - config=ImmediateFilter( - entity_name=entity_name, - expression=success_condition, - parent=ParentMetadata( - rule="FilterStageRecordLevelFilterApplication", index=0, stage="Sync" - ), - ), + with BackgroundMessageWriter( + working_directory=working_directory, + dve_stage="business_rules", + key_fields=key_fields, + logger=self.logger, + ) as msg_writer: + for entity_name, filter_rules in filters_by_entity.items(): + self.logger.info(f"Applying filters to {entity_name}") + entity = entities[entity_name] + + filter_column_names: list[str] = [] + unmodified_entities = {entity_name: entity} + modified_entities = {entity_name: entity} + + for rule in filter_rules: + self.logger.info(f"Applying filter {rule.reporting.code}") + if rule.reporting.emit == "record_failure": + column_name = f"filter_{uuid4().hex}" + filter_column_names.append(column_name) + temp_messages, success = self.evaluate( + modified_entities, + config=ColumnAddition( + entity_name=entity_name, + column_name=column_name, + expression=rule.expression, + parent=rule.parent, + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while applying filter logic", + messages=[msg.error_message for msg in temp_messages if msg.error_message], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + temp_messages, success = self.evaluate( + modified_entities, + config=Notification( + entity_name=entity_name, + expression=f"NOT {column_name}", + excluded_columns=filter_column_names, + reporting=rule.reporting, + parent=rule.parent, + ), + ) + if not success: + processing_errors_uri = 
dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while generating FeedbackMessages", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + self.logger.info( + f"Filter {rule.reporting.code} found {len(temp_messages)} issues" + ) + + else: + temp_messages, success = self.evaluate( + unmodified_entities, + config=Notification( + entity_name=entity_name, + expression=f"NOT ({rule.expression})", + reporting=rule.reporting, + parent=rule.parent, + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while generating FeedbackMessages", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + self.logger.info( + f"Filter {rule.reporting.code} found {len(temp_messages)} issues" + ) + + if filter_column_names: + self.logger.info( + f"Filtering records from entity {entity_name} for error code {rule.reporting.code}" # pylint: disable=line-too-long ) - messages.extend(temp_messages) - if not success: - return messages, False - - for index, filter_column_name in enumerate(filter_column_names): + success_condition = " AND ".join( + [f"({c_name} IS NOT NULL AND {c_name})" for c_name in filter_column_names] + ) temp_messages, success = self.evaluate( modified_entities, - config=ColumnRemoval( + config=ImmediateFilter( entity_name=entity_name, - column_name=filter_column_name, + expression=success_condition, parent=ParentMetadata( - rule="FilterStageRecordLevelFilterColumnRemoval", - index=index, + rule="FilterStageRecordLevelFilterApplication", + index=0, stage="Sync", ), ), ) - messages.extend(temp_messages) if not success: - return messages, False - - entities.update(modified_entities) - - return messages, True - - def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messages: + processing_errors_uri = dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while filtering error records", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + for index, filter_column_name in enumerate(filter_column_names): + temp_messages, success = self.evaluate( + modified_entities, + config=ColumnRemoval( + entity_name=entity_name, + column_name=filter_column_name, + parent=ParentMetadata( + rule="FilterStageRecordLevelFilterColumnRemoval", + index=index, + stage="Sync", + ), + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while generating FeedbackMessages", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + entities.update(modified_entities) + + return feedback_errors_uri, True + + def apply_rules( + self, + working_directory: URI, + entities: Entities, + rule_metadata: RuleMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[URI, bool]: """Create rule definitions from the metadata for a given dataset and evaluate the impact on the provided entities, returning a deque of messages and 
altering the entities in-place. """ + self.logger.info("Applying business rules") rules_and_locals: Iterable[tuple[Rule, TemplateVariables]] + errors_uri = get_feedback_errors_uri(working_directory, "business_rules") if rule_metadata.templating_strategy == "upfront": rules_and_locals = [] for rule, local_variables in rule_metadata: @@ -471,7 +566,8 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag else: rules_and_locals = rule_metadata - messages: Messages = [] + pre_sync_messages: Messages = [] + self.logger.info("Applying pre-sync steps") for rule, local_variables in rules_and_locals: for step in rule.pre_sync_steps: if rule_metadata.templating_strategy == "runtime": @@ -480,9 +576,27 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag ) stage_messages, success = self.evaluate(entities, config=step) - messages.extend(stage_messages) + # if failure, write out processing issues and all prior messages (so nothing lost) if not success: - return messages + processing_errors_uri = dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while applying pre filter steps", + [msg.error_message for msg in stage_messages], + ) + ], + ) + if pre_sync_messages: + dump_feedback_errors(working_directory, "business_rules", pre_sync_messages) + + return processing_errors_uri, False + # if not a failure, ensure we keep track of any informational messages + pre_sync_messages.extend(stage_messages) + # if all successful, ensure we write out all informational messages + if pre_sync_messages: + dump_feedback_errors(working_directory, "business_rules", pre_sync_messages) sync_steps = [] for rule, local_variables in rules_and_locals: @@ -493,10 +607,15 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag ) sync_steps.append(step) - stage_messages, success = self.apply_sync_filters(entities, *sync_steps) - messages.extend(stage_messages) + # error writing handled in apply_sync_filters + errors_uri, success = self.apply_sync_filters( + working_directory, entities, *sync_steps, key_fields=key_fields + ) if not success: - return messages + return errors_uri, False + + post_sync_messages: Messages = [] + self.logger.info("Applying post-sync steps") for rule, local_variables in rules_and_locals: for step in rule.post_sync_steps: @@ -506,10 +625,29 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag ) stage_messages, success = self.evaluate(entities, config=step) - messages.extend(stage_messages) if not success: - return messages - return messages + processing_errors_uri = dump_processing_errors( + working_directory, + "business_rules", + [ + CriticalProcessingError( + "Issue occurred while applying post filter steps", + [msg.error_message for msg in stage_messages], + ) + ], + ) + if post_sync_messages: + dump_feedback_errors( + working_directory, "business_rules", post_sync_messages + ) + + return processing_errors_uri, False + # if not a failure, ensure we keep track of any informational messages + post_sync_messages.extend(stage_messages) + # if all successful, ensure we write out all informational messages + if post_sync_messages: + dump_feedback_errors(working_directory, "business_rules", post_sync_messages) + return errors_uri, True def read_parquet(self, path: URI, **kwargs) -> EntityType: """Method to read parquet files""" diff --git a/src/dve/core_engine/backends/base/utilities.py 
b/src/dve/core_engine/backends/base/utilities.py index 30efc74..f55bc88 100644 --- a/src/dve/core_engine/backends/base/utilities.py +++ b/src/dve/core_engine/backends/base/utilities.py @@ -5,8 +5,12 @@ from collections.abc import Sequence from typing import Optional +import pyarrow # type: ignore +import pyarrow.parquet as pq # type: ignore + from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import ExpressionArray, MultiExpression +from dve.parser.type_hints import URI BRACKETS = {"(": ")", "{": "}", "[": "]", "<": ">"} """A mapping of opening brackets to their closing counterpart.""" @@ -131,3 +135,12 @@ def _get_non_heterogenous_type(types: Sequence[type]) -> type: + f"union types (got {type_list!r}) but nullable types are okay" ) return type_list[0] + + +def check_if_parquet_file(file_location: URI) -> bool: + """Check if a file path is valid parquet""" + try: + pq.ParquetFile(file_location) + return True + except (pyarrow.ArrowInvalid, pyarrow.ArrowIOError): + return False diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 5113da5..af2ddba 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -3,21 +3,32 @@ # pylint: disable=R0903 import logging from collections.abc import Iterator +from functools import partial +from multiprocessing import Pool from typing import Any, Optional from uuid import uuid4 import pandas as pd import polars as pl +import pyarrow.parquet as pq # type: ignore from duckdb import DuckDBPyConnection, DuckDBPyRelation from duckdb.typing import DuckDBPyType from polars.datatypes.classes import DataTypeClass as PolarsType from pydantic import BaseModel from pydantic.fields import ModelField +import dve.parser.file_handling as fh +from dve.common.error_utils import ( + BackgroundMessageWriter, + dump_processing_errors, + get_feedback_errors_uri, +) from dve.core_engine.backends.base.contract import BaseDataContract -from dve.core_engine.backends.base.utilities import generate_error_casting_entity_message +from dve.core_engine.backends.base.utilities import ( + check_if_parquet_file, + generate_error_casting_entity_message, +) from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( - coerce_inferred_numpy_array_to_list, duckdb_read_parquet, duckdb_write_parquet, get_duckdb_type_from_annotation, @@ -28,8 +39,8 @@ from dve.core_engine.backends.types import StageSuccessful from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model from dve.core_engine.message import FeedbackMessage -from dve.core_engine.type_hints import URI, Messages -from dve.core_engine.validation import RowValidator +from dve.core_engine.type_hints import URI, EntityLocations +from dve.core_engine.validation import RowValidator, apply_row_validator_helper class PandasApplyHelper: @@ -49,7 +60,8 @@ def __call__(self, row: pd.Series): class DuckDBDataContract(BaseDataContract[DuckDBPyRelation]): """An implementation of a data contract in DuckDB. - This utilises the conversion from relation to pandas dataframe to apply the data contract. + This utilises pyarrow to distibute parquet data across python processes and + a background process to write error messages. 
""" @@ -71,8 +83,8 @@ def connection(self) -> DuckDBPyConnection: """The duckdb connection""" return self._connection - def _cache_records(self, relation: DuckDBPyRelation, cache_prefix: URI) -> URI: - chunk_uri = "/".join((cache_prefix.rstrip("/"), str(uuid4()))) + ".parquet" + def _cache_records(self, relation: DuckDBPyRelation, working_dir: URI) -> URI: + chunk_uri = "/".join((working_dir.rstrip("/"), str(uuid4()))) + ".parquet" self.write_parquet(entity=relation, target_location=chunk_uri) return chunk_uri @@ -98,75 +110,108 @@ def generate_ddb_cast_statement( return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"' return f'cast(NULL AS {dtype}) AS "{column_name}"' + # pylint: disable=R0914 def apply_data_contract( - self, entities: DuckDBEntities, contract_metadata: DataContractMetadata - ) -> tuple[DuckDBEntities, Messages, StageSuccessful]: + self, + working_dir: URI, + entities: DuckDBEntities, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[DuckDBEntities, URI, StageSuccessful]: """Apply the data contract to the duckdb relations""" self.logger.info("Applying data contracts") - all_messages: Messages = [] + feedback_errors_uri: URI = get_feedback_errors_uri(working_dir, "data_contract") + + # check if entities are valid parquet - if not, convert + for entity, entity_loc in entity_locations.items(): + if not check_if_parquet_file(entity_loc): + parquet_uri = self.write_parquet( + entities[entity], fh.joinuri(fh.get_parent(entity_loc), f"{entity}.parquet") + ) + entity_locations[entity] = parquet_uri successful = True - for entity_name, relation in entities.items(): - # get dtypes for all fields -> python data types or use with relation - entity_fields: dict[str, ModelField] = contract_metadata.schemas[entity_name].__fields__ - ddb_schema: dict[str, DuckDBPyType] = { - fld.name: get_duckdb_type_from_annotation(fld.annotation) - for fld in entity_fields.values() - } - polars_schema: dict[str, PolarsType] = { - fld.name: get_polars_type_from_annotation(fld.annotation) - for fld in entity_fields.values() - } - if relation_is_empty(relation): - self.logger.warning(f"+ Empty relation for {entity_name}") - empty_df = pl.DataFrame([], schema=polars_schema) # type: ignore # pylint: disable=W0612 - relation = self._connection.sql("select * from empty_df") - continue - - self.logger.info(f"+ Applying contract to: {entity_name}") - - row_validator = contract_metadata.validators[entity_name] - application_helper = PandasApplyHelper(row_validator) - self.logger.info("+ Applying data contract") - coerce_inferred_numpy_array_to_list(relation.df()).apply( - application_helper, axis=1 - ) # pandas uses eager evaluation so potential memory issue here? 
- all_messages.extend(application_helper.errors) - - casting_statements = [ - ( - self.generate_ddb_cast_statement(column, dtype) - if column in relation.columns - else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + + with BackgroundMessageWriter( + working_dir, "data_contract", key_fields=key_fields + ) as msg_writer: + for entity_name, relation in entities.items(): + # get dtypes for all fields -> python data types or use with relation + entity_fields: dict[str, ModelField] = contract_metadata.schemas[ + entity_name + ].__fields__ + ddb_schema: dict[str, DuckDBPyType] = { + fld.name: get_duckdb_type_from_annotation(fld.annotation) + for fld in entity_fields.values() + } + polars_schema: dict[str, PolarsType] = { + fld.name: get_polars_type_from_annotation(fld.annotation) + for fld in entity_fields.values() + } + if relation_is_empty(relation): + self.logger.warning(f"+ Empty relation for {entity_name}") + empty_df = pl.DataFrame([], schema=polars_schema) # type: ignore # pylint: disable=W0612 + relation = self._connection.sql("select * from empty_df") + continue + + self.logger.info(f"+ Applying contract to: {entity_name}") + + row_validator_helper = partial( + apply_row_validator_helper, + row_validator=contract_metadata.validators[entity_name], ) - for column, dtype in ddb_schema.items() - ] - try: - relation = relation.project(", ".join(casting_statements)) - except Exception as err: # pylint: disable=broad-except - successful = False - self.logger.error(f"Error in casting relation: {err}") - all_messages.append(generate_error_casting_entity_message(entity_name)) - continue - - if self.debug: - # count will force evaluation - only done in debug - pre_convert_row_count = relation.count("*").fetchone()[0] # type: ignore - self.logger.info(f"+ Converting to parquet: ({pre_convert_row_count} rows)") - else: - pre_convert_row_count = 0 - self.logger.info("+ Converting to parquet") - - entities[entity_name] = relation - if self.debug: - post_convert_row_count = entities[entity_name].count("*").fetchone()[0] # type: ignore # pylint:disable=line-too-long - self.logger.info(f"+ Converted to parquet: ({post_convert_row_count} rows)") - if post_convert_row_count != pre_convert_row_count: - raise ValueError( - f"Row count mismatch for {entity_name}" - f" ({pre_convert_row_count} vs {post_convert_row_count})" - ) - else: - self.logger.info("+ Converted to parquet") - return entities, all_messages, successful + batches = pq.ParquetFile(entity_locations[entity_name]).iter_batches(10000) + msg_count = 0 + with Pool(8) as pool: + for msgs in pool.imap_unordered(row_validator_helper, batches): + if msgs: + msg_writer.write_queue.put(msgs) + msg_count += len(msgs) + + self.logger.info( + f"Data contract found {msg_count} issues in {entity_name}" + ) + + casting_statements = [ + ( + self.generate_ddb_cast_statement(column, dtype) + if column in relation.columns + else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + ) + for column, dtype in ddb_schema.items() + ] + try: + relation = relation.project(", ".join(casting_statements)) + except Exception as err: # pylint: disable=broad-except + successful = False + self.logger.error(f"Error in casting relation: {err}") + dump_processing_errors( + working_dir, + "data_contract", + [generate_error_casting_entity_message(entity_name)], + ) + continue + + if self.debug: + # count will force evaluation - only done in debug + pre_convert_row_count = relation.count("*").fetchone()[0] # type: ignore + self.logger.info(f"+ Converting 
to parquet: ({pre_convert_row_count} rows)") + else: + pre_convert_row_count = 0 + self.logger.info("+ Converting to parquet") + + entities[entity_name] = relation + if self.debug: + post_convert_row_count = entities[entity_name].count("*").fetchone()[0] # type: ignore # pylint:disable=line-too-long + self.logger.info(f"+ Converted to parquet: ({post_convert_row_count} rows)") + if post_convert_row_count != pre_convert_row_count: + raise ValueError( + f"Row count mismatch for {entity_name}" + f" ({pre_convert_row_count} vs {post_convert_row_count})" + ) + else: + self.logger.info("+ Converted to parquet") + + return entities, feedback_errors_uri, successful diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py index 3998bf5..ff65d9f 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py +++ b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py @@ -16,6 +16,7 @@ get_duckdb_type_from_annotation, ) from dve.core_engine.backends.implementations.duckdb.types import SQLType +from dve.core_engine.backends.readers.utilities import check_csv_header_expected from dve.core_engine.backends.utilities import get_polars_type_from_annotation from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import URI, EntityName @@ -24,7 +25,14 @@ @duckdb_write_parquet class DuckDBCSVReader(BaseFileReader): - """A reader for CSV files""" + """A reader for CSV files including the ability to compare the passed model + to the file header, if it exists. + + field_check: flag to compare submitted file header to the accompanying pydantic model + field_check_error_code: The error code to provide if the file header doesn't contain + the expected fields + field_check_error_message: The error message to provide if the file header doesn't contain + the expected fields""" # TODO - the read_to_relation should include the schema and determine whether to # TODO - stringify or not @@ -35,15 +43,43 @@ def __init__( delim: str = ",", quotechar: str = '"', connection: Optional[DuckDBPyConnection] = None, + field_check: bool = False, + field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch", + field_check_error_message: Optional[str] = "The submitted header is missing fields", **_, ): self.header = header self.delim = delim self.quotechar = quotechar self._connection = connection if connection else default_connection + self.field_check = field_check + self.field_check_error_code = field_check_error_code + self.field_check_error_message = field_check_error_message super().__init__() + def perform_field_check( + self, resource: URI, entity_name: str, expected_schema: type[BaseModel] + ): + """Check that the header of the CSV aligns with the provided model""" + if not self.header: + raise ValueError("Cannot perform field check without a CSV header") + + if missing := check_csv_header_expected(resource, expected_schema, self.delim): + raise MessageBearingError( + "The CSV header doesn't match what is expected", + messages=[ + FeedbackMessage( + entity=entity_name, + record=None, + failure_type="submission", + error_location="Whole File", + error_code=self.field_check_error_code, + error_message=f"{self.field_check_error_message} - missing fields: {missing}", # pylint: disable=line-too-long + ) + ], + ) + def read_to_py_iterator( self, resource: URI, entity_name: EntityName, schema: type[BaseModel] ) -> Iterator[dict[str, Any]]: @@ -58,6 +94,9 @@ def 
read_to_relation( # pylint: disable=unused-argument if get_content_length(resource) == 0: raise EmptyFileError(f"File at {resource} is empty.") + if self.field_check: + self.perform_field_check(resource, entity_name, schema) + reader_options: dict[str, Any] = { "header": self.header, "delimiter": self.delim, @@ -89,6 +128,9 @@ def read_to_relation( # pylint: disable=unused-argument if get_content_length(resource) == 0: raise EmptyFileError(f"File at {resource} is empty.") + if self.field_check: + self.perform_field_check(resource, entity_name, schema) + reader_options: dict[str, Any] = { "has_header": self.header, "separator": self.delim, @@ -132,6 +174,17 @@ class DuckDBCSVRepeatingHeaderReader(PolarsToDuckDBCSVReader): | shop1 | clothes | 2025-01-01 | """ + def __init__( + self, + *args, + non_unique_header_error_code: Optional[str] = "NonUniqueHeader", + non_unique_header_error_message: Optional[str] = None, + **kwargs, + ): + self._non_unique_header_code = non_unique_header_error_code + self._non_unique_header_message = non_unique_header_error_message + super().__init__(*args, **kwargs) + @read_function(DuckDBPyRelation) def read_to_relation( # pylint: disable=unused-argument self, resource: URI, entity_name: EntityName, schema: type[BaseModel] @@ -156,10 +209,12 @@ def read_to_relation( # pylint: disable=unused-argument failure_type="submission", error_message=( f"Found {no_records} distinct combination of header values." + if not self._non_unique_header_message + else self._non_unique_header_message ), error_location=entity_name, category="Bad file", - error_code="NonUniqueHeader", + error_code=self._non_unique_header_code, ) ], ) diff --git a/src/dve/core_engine/backends/implementations/spark/contract.py b/src/dve/core_engine/backends/implementations/spark/contract.py index bbd2d5a..d8078bd 100644 --- a/src/dve/core_engine/backends/implementations/spark/contract.py +++ b/src/dve/core_engine/backends/implementations/spark/contract.py @@ -2,6 +2,7 @@ import logging from collections.abc import Iterator +from itertools import islice from typing import Any, Optional from uuid import uuid4 @@ -11,6 +12,11 @@ from pyspark.sql.functions import col, lit from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType +from dve.common.error_utils import ( + BackgroundMessageWriter, + dump_processing_errors, + get_feedback_errors_uri, +) from dve.core_engine.backends.base.contract import BaseDataContract, reader_override from dve.core_engine.backends.base.utilities import generate_error_casting_entity_message from dve.core_engine.backends.exceptions import ( @@ -29,7 +35,7 @@ from dve.core_engine.backends.readers import CSVFileReader from dve.core_engine.backends.types import StageSuccessful from dve.core_engine.constants import ROWID_COLUMN_NAME -from dve.core_engine.type_hints import URI, EntityName, Messages +from dve.core_engine.type_hints import URI, EntityLocations, EntityName COMPLEX_TYPES: set[type[DataType]] = {StructType, ArrayType, MapType} """Spark types indicating complex types.""" @@ -61,12 +67,12 @@ def __init__( super().__init__(logger, **kwargs) - def _cache_records(self, dataframe: DataFrame, cache_prefix: URI) -> URI: + def _cache_records(self, dataframe: DataFrame, working_dir: URI) -> URI: """Write a chunk of records out to the cache dir, returning the path to the parquet file. 
""" - chunk_uri = "/".join((cache_prefix.rstrip("/"), str(uuid4()))) + ".parquet" + chunk_uri = "/".join((working_dir.rstrip("/"), str(uuid4()))) + ".parquet" dataframe.write.parquet(chunk_uri) return chunk_uri @@ -79,10 +85,17 @@ def create_entity_from_py_iterator( ) def apply_data_contract( - self, entities: SparkEntities, contract_metadata: DataContractMetadata - ) -> tuple[SparkEntities, Messages, StageSuccessful]: + self, + working_dir: URI, + entities: SparkEntities, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[SparkEntities, URI, StageSuccessful]: self.logger.info("Applying data contracts") - all_messages: Messages = [] + + entity_locations = {} if not entity_locations else entity_locations + feedback_errors_uri = get_feedback_errors_uri(working_dir, "data_contract") successful = True for entity_name, record_df in entities.items(): @@ -112,8 +125,18 @@ def apply_data_contract( record_df.rdd.map(lambda row: row.asDict(True)).map(row_validator) # .persist(storageLevel=StorageLevel.MEMORY_AND_DISK) ) - messages = validated.flatMap(lambda row: row[1]).filter(bool) - all_messages.extend(messages.collect()) + with BackgroundMessageWriter( + working_dir, "data_contract", key_fields, self.logger + ) as msg_writer: + messages = validated.flatMap(lambda row: row[1]).filter(bool).toLocalIterator() + msg_count = 0 + while True: + batch = list(islice(messages, 10000)) + if not batch: + break + msg_writer.write_queue.put(batch) + msg_count += len(batch) + self.logger.info(f"Data contract found {msg_count} issues in {entity_name}") try: record_df = record_df.select( @@ -126,7 +149,11 @@ def apply_data_contract( except Exception as err: # pylint: disable=broad-except successful = False self.logger.error(f"Error in converting to dataframe: {err}") - all_messages.append(generate_error_casting_entity_message(entity_name)) + dump_processing_errors( + working_dir, + "data_contract", + [generate_error_casting_entity_message(entity_name)], + ) continue if self.debug: @@ -152,7 +179,7 @@ def apply_data_contract( else: self.logger.info("+ Converted to Dataframe") - return entities, all_messages, successful + return entities, feedback_errors_uri, successful @reader_override(CSVFileReader) def read_csv_file( diff --git a/src/dve/core_engine/backends/readers/utilities.py b/src/dve/core_engine/backends/readers/utilities.py new file mode 100644 index 0000000..642c0b2 --- /dev/null +++ b/src/dve/core_engine/backends/readers/utilities.py @@ -0,0 +1,21 @@ +"""General utilities for file readers""" + +from typing import Optional + +from pydantic import BaseModel + +from dve.core_engine.type_hints import URI +from dve.parser.file_handling.service import open_stream + + +def check_csv_header_expected( + resource: URI, + expected_schema: type[BaseModel], + delimiter: Optional[str] = ",", + quote_char: str = '"', +) -> set[str]: + """Check the header of a CSV matches the expected fields""" + with open_stream(resource) as fle: + header_fields = fle.readline().rstrip().replace(quote_char, "").split(delimiter) + expected_fields = expected_schema.__fields__.keys() + return set(expected_fields).difference(header_fields) diff --git a/src/dve/core_engine/backends/utilities.py b/src/dve/core_engine/backends/utilities.py index bfa6f90..9261806 100644 --- a/src/dve/core_engine/backends/utilities.py +++ b/src/dve/core_engine/backends/utilities.py @@ -4,8 +4,7 @@ from dataclasses import is_dataclass from datetime import date, 
datetime, time from decimal import Decimal -from typing import GenericAlias # type: ignore -from typing import Any, ClassVar, Union +from typing import Any, ClassVar, GenericAlias, Union # type: ignore import polars as pl # type: ignore from polars.datatypes.classes import DataTypeClass as PolarsType diff --git a/src/dve/core_engine/engine.py b/src/dve/core_engine/engine.py index 87ab0b6..28a2ac5 100644 --- a/src/dve/core_engine/engine.py +++ b/src/dve/core_engine/engine.py @@ -1,12 +1,8 @@ """The core engine for the data validation engine.""" -import csv import json import logging -import shutil -from contextlib import ExitStack from pathlib import Path -from tempfile import NamedTemporaryFile from types import TracebackType from typing import Any, Optional, Union @@ -21,16 +17,9 @@ from dve.core_engine.configuration.v1 import V1EngineConfig from dve.core_engine.constants import ROWID_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger -from dve.core_engine.message import FeedbackMessage from dve.core_engine.models import EngineRunValidation, SubmissionInfo -from dve.core_engine.type_hints import EntityName, JSONstring, Messages -from dve.parser.file_handling import ( - TemporaryPrefix, - get_resource_exists, - joinuri, - open_stream, - resolve_location, -) +from dve.core_engine.type_hints import EntityName, JSONstring +from dve.parser.file_handling import TemporaryPrefix, get_resource_exists, joinuri, resolve_location from dve.parser.type_hints import URI, Location @@ -47,7 +36,7 @@ class Config: # pylint: disable=too-few-public-methods """The backend configuration for the given run.""" dataset_config_uri: URI """The dischema location for the current run""" - output_prefix_uri: URI = Field(default_factory=lambda: Path("outputs").resolve().as_uri()) + output_prefix_uri: URI = Field(default_factory=lambda: Path("outputs").resolve().as_posix()) """The prefix for the parquet outputs.""" main_log: logging.Logger = Field(default_factory=lambda: get_logger("CoreEngine")) """The `logging.Logger instance for the data ingest process.""" @@ -129,14 +118,19 @@ def build( main_log.info(f"Debug mode: {debug}") if isinstance(dataset_config_path, Path): - dataset_config_uri = dataset_config_path.resolve().as_uri() + dataset_config_uri = dataset_config_path.resolve().as_posix() else: dataset_config_uri = dataset_config_path + if isinstance(output_prefix, Path): + output_prefix_uri = output_prefix.resolve().as_posix() + else: + output_prefix_uri = output_prefix + backend_config = V1EngineConfig.load(dataset_config_uri) self = cls( dataset_config_uri=dataset_config_uri, - output_prefix_uri=output_prefix, + output_prefix_uri=output_prefix_uri, main_log=main_log, cache_prefix_uri=cache_prefix, backend_config=backend_config, @@ -223,64 +217,14 @@ def _write_entity_outputs(self, entities: SparkEntities) -> SparkEntities: return output_entities - def _write_exception_report(self, messages: Messages) -> None: - """Write an exception report to the ouptut prefix. This is currently - a pipe-delimited CSV file containing all the messages emitted by the - pipeline. - - """ - # need to write using temp files and put to s3 with self.fs_impl? 
- - self.main_log.info(f"Creating exception report in the output dir: {self.output_prefix_uri}") - self.main_log.info("Splitting errors by category") - - contract_metadata = self.backend_config.get_contract_metadata() - with ExitStack() as file_contexts: - critical_file = file_contexts.enter_context(NamedTemporaryFile("r+", encoding="utf-8")) - critical_writer = csv.writer(critical_file, delimiter="|", lineterminator="\n") - - standard_file = file_contexts.enter_context(NamedTemporaryFile("r+", encoding="utf-8")) - standard_writer = csv.writer(standard_file, delimiter="|", lineterminator="\n") - - warning_file = file_contexts.enter_context(NamedTemporaryFile("r+", encoding="utf-8")) - warning_writer = csv.writer(warning_file, delimiter="|", lineterminator="\n") - - for message in messages: - if message.entity: - key_field = contract_metadata.reporting_fields.get(message.entity, None) - else: - key_field = None - message_row = message.to_row(key_field) - - if message.is_critical: - critical_writer.writerow(message_row) - elif message.is_informational: - warning_writer.writerow(message_row) - else: - standard_writer.writerow(message_row) - - output_uri = joinuri(self.output_prefix_uri, "pipeline.errors") - with open_stream(output_uri, "w", encoding="utf-8") as combined_error_file: - combined_error_file.write("|".join(FeedbackMessage.HEADER) + "\n") - for temp_error_file in [critical_file, standard_file, warning_file]: - temp_error_file.seek(0) - shutil.copyfileobj(temp_error_file, combined_error_file) - - def _write_outputs( - self, entities: SparkEntities, messages: Messages, verbose: bool = False - ) -> tuple[SparkEntities, Messages]: + def _write_outputs(self, entities: SparkEntities) -> SparkEntities: """Write the outputs from the pipeline, returning the written entities and messages. """ entities = self._write_entity_outputs(entities) - if verbose: - self._write_exception_report(messages) - else: - self.main_log.info("Skipping exception report") - - return entities, messages + return entities def _show_available_entities(self, entities: SparkEntities, *, verbose: bool = False) -> None: """Print current entities.""" @@ -301,11 +245,9 @@ def _show_available_entities(self, entities: SparkEntities, *, verbose: bool = F def run_pipeline( self, entity_locations: dict[EntityName, URI], - *, - verbose: bool = False, # pylint: disable=unused-argument submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[SparkEntities, Messages]: + ) -> tuple[SparkEntities, URI]: """Run the pipeline, reading in the entities and applying validation and transformation rules, and then write the outputs. @@ -313,11 +255,11 @@ def run_pipeline( references should be valid after the pipeline context exits. 
""" - entities, messages = self.backend.process_legacy( + entities, errors_uri = self.backend.process_legacy( + self.output_prefix_uri, entity_locations, self.backend_config.get_contract_metadata(), self.backend_config.get_rule_metadata(), - self.cache_prefix, submission_info, ) - return self._write_outputs(entities, messages, verbose=verbose) + return self._write_outputs(entities), errors_uri diff --git a/src/dve/core_engine/exceptions.py b/src/dve/core_engine/exceptions.py index cba7508..4877baf 100644 --- a/src/dve/core_engine/exceptions.py +++ b/src/dve/core_engine/exceptions.py @@ -1,12 +1,8 @@ """Exceptions emitted by the pipeline.""" -from collections.abc import Iterator +import traceback from typing import Optional -from dve.core_engine.backends.implementations.spark.types import SparkEntities -from dve.core_engine.message import FeedbackMessage -from dve.core_engine.type_hints import Messages - class CriticalProcessingError(ValueError): """An exception emitted if critical errors are received.""" @@ -15,26 +11,18 @@ def __init__( self, error_message: str, *args: object, - messages: Optional[Messages], - entities: Optional[SparkEntities] = None + messages: Optional[list[str]] = None, ) -> None: super().__init__(error_message, *args) self.error_message = error_message """The error message explaining the critical processing error.""" self.messages = messages - """The messages gathered at the time the error was emitted.""" - self.entities = entities - """The entities as they exist at the time the error was emitted.""" - - @property - def critical_messages(self) -> Iterator[FeedbackMessage]: - """Critical messages which caused the processing error.""" - yield from filter(lambda message: message.is_critical, self.messages) # type: ignore + """The stacktrace for the messages.""" @classmethod def from_exception(cls, exc: Exception): """Create from broader exception, for recording in processing errors""" - return cls(error_message=repr(exc), entities=None, messages=[]) + return cls(error_message=repr(exc), messages=traceback.format_exception(exc)) class EntityTypeMismatch(TypeError): diff --git a/src/dve/core_engine/message.py b/src/dve/core_engine/message.py index dd580c6..f2a4e52 100644 --- a/src/dve/core_engine/message.py +++ b/src/dve/core_engine/message.py @@ -1,5 +1,6 @@ """Functionality to represent messages.""" +# pylint: disable=C0103 import copy import datetime as dt import json @@ -17,10 +18,16 @@ from dve.core_engine.type_hints import ( EntityName, ErrorCategory, + ErrorCode, + ErrorLocation, + ErrorMessage, + ErrorType, FailureType, Messages, MessageTuple, Record, + ReportingField, + Status, ) from dve.parser.type_hints import FieldName @@ -82,6 +89,45 @@ class Config: # pylint: disable=too-few-public-methods arbitrary_types_allowed = True +# pylint: disable=R0902 +@dataclass +class UserMessage: + """The structure of the message that is used to populate the error report.""" + + Entity: Optional[str] + """The entity that the message pertains to (if applicable).""" + Key: Optional[str] + "The key field(s) in string format to allow users to identify the record" + FailureType: FailureType + "The type of failure" + Status: Status + "Indicating if an error or warning" + ErrorType: ErrorType + "The type of error" + ErrorLocation: ErrorLocation + "The source of the error" + ErrorMessage: ErrorMessage + "The error message to summarise the error" + ErrorCode: ErrorCode + "The error code of the error" + ReportingField: ReportingField + "The field(s) that the error relates to" + Value: Any 
+ "The offending values" + Category: ErrorCategory + "The category of error" + + @property + def is_informational(self) -> bool: + "Indicates whether the message is a warning" + return self.Status == "informational" + + @property + def is_critical(self) -> bool: + "Indicates if the message relates to a processing issue" + return self.FailureType == "integrity" + + @dataclass(config=Config, eq=True) class FeedbackMessage: # pylint: disable=too-many-instance-attributes """Information which affects processing and needs to be feeded back.""" diff --git a/src/dve/core_engine/type_hints.py b/src/dve/core_engine/type_hints.py index 0be3763..cd4aa18 100644 --- a/src/dve/core_engine/type_hints.py +++ b/src/dve/core_engine/type_hints.py @@ -244,3 +244,12 @@ BinaryComparator = Callable[[Any, Any], bool] """Type hint for operator functions""" + +DVEStage = Literal[ + "audit_received", + "file_transformation", + "data_contract", + "business_rules", + "error_report", + "pipeline", +] diff --git a/src/dve/core_engine/validation.py b/src/dve/core_engine/validation.py index 2be101e..f62309b 100644 --- a/src/dve/core_engine/validation.py +++ b/src/dve/core_engine/validation.py @@ -1,8 +1,11 @@ """XML schema/contract configuration.""" +# pylint: disable=E0611 import warnings +from itertools import chain from typing import Optional +from pyarrow.lib import RecordBatch # type: ignore from pydantic import ValidationError from pydantic.main import ModelMetaclass @@ -145,3 +148,13 @@ def handle_warnings(self, record, caught_warnings) -> list[FeedbackMessage]: ) ) return messages + + +def apply_row_validator_helper( + arrow_batch: RecordBatch, *, row_validator: RowValidator +) -> Messages: + """Helper to distribute data efficiently over python processes and then convert + to dictionaries for applying pydantic model""" + return list( + chain.from_iterable(msgs for _, msgs in map(row_validator, arrow_batch.to_pylist())) + ) diff --git a/src/dve/parser/file_handling/implementations/file.py b/src/dve/parser/file_handling/implementations/file.py index eeed3de..76d8b58 100644 --- a/src/dve/parser/file_handling/implementations/file.py +++ b/src/dve/parser/file_handling/implementations/file.py @@ -9,7 +9,7 @@ from typing_extensions import Literal -from dve.parser.exceptions import FileAccessError, UnsupportedSchemeError +from dve.parser.exceptions import FileAccessError from dve.parser.file_handling.helpers import parse_uri from dve.parser.file_handling.implementations.base import BaseFilesystemImplementation from dve.parser.type_hints import URI, NodeType, PathStr, Scheme @@ -20,11 +20,7 @@ def file_uri_to_local_path(uri: URI) -> Path: """Resolve a `file://` URI to a local filesystem path.""" - scheme, hostname, path = parse_uri(uri) - if scheme not in FILE_URI_SCHEMES: # pragma: no cover - raise UnsupportedSchemeError( - f"Local filesystem must use an allowed file URI scheme, got {scheme!r}" - ) + _, hostname, path = parse_uri(uri) path = unquote(path) # Unfortunately Windows is awkward. @@ -54,7 +50,7 @@ def _path_to_uri(self, path: Path) -> URI: easier to create implementations for 'file-like' protocols. 
""" - return path.as_uri() + return path.as_posix() @staticmethod def _handle_error( diff --git a/src/dve/parser/file_handling/service.py b/src/dve/parser/file_handling/service.py index 0422b4c..9ee9d9f 100644 --- a/src/dve/parser/file_handling/service.py +++ b/src/dve/parser/file_handling/service.py @@ -4,7 +4,7 @@ """ -# pylint: disable=logging-not-lazy +# pylint: disable=logging-not-lazy, unidiomatic-typecheck, protected-access import hashlib import platform import shutil @@ -286,6 +286,7 @@ def create_directory(target_uri: URI): return +# pylint: disable=too-many-branches def _transfer_prefix( source_prefix: URI, target_prefix: URI, overwrite: bool, action: Literal["copy", "move"] ): @@ -296,26 +297,37 @@ def _transfer_prefix( if action not in ("move", "copy"): # pragma: no cover raise ValueError(f"Unsupported action {action!r}, expected one of: 'copy', 'move'") - if not source_prefix.endswith("/"): - source_prefix += "/" - if not target_prefix.endswith("/"): - target_prefix += "/" - source_uris: list[URI] = [] target_uris: list[URI] = [] source_impl = _get_implementation(source_prefix) target_impl = _get_implementation(target_prefix) + if type(source_impl) == LocalFilesystemImplementation: + source_prefix = source_impl._path_to_uri(source_impl._uri_to_path(source_prefix)) + + if type(target_impl) == LocalFilesystemImplementation: + target_prefix = target_impl._path_to_uri(target_impl._uri_to_path(target_prefix)) + + if not source_prefix.endswith("/"): + source_prefix += "/" + if not target_prefix.endswith("/"): + target_prefix += "/" + for source_uri, node_type in source_impl.iter_prefix(source_prefix, True): if node_type != "resource": continue if not source_uri.startswith(source_prefix): # pragma: no cover - raise FileAccessError( - f"Listed URI ({source_uri!r}) not relative to source prefix " - + f"({source_prefix!r})" - ) + if type(_get_implementation(source_uri)) == LocalFilesystemImplementation: + # Due to local file systems having issues with local file scheme, + # stripping this check off + pass + else: + raise FileAccessError( + f"Listed URI ({source_uri!r}) not relative to source prefix " + + f"({source_prefix!r})" + ) path_within_prefix = source_uri[len(source_prefix) :] target_uri = target_prefix + path_within_prefix @@ -359,11 +371,11 @@ def move_prefix(source_prefix: URI, target_prefix: URI, overwrite: bool = False) def resolve_location(filename_or_url: Location) -> URI: """Resolve a union of filename and URI to a URI.""" if isinstance(filename_or_url, Path): - return filename_or_url.expanduser().resolve().as_uri() + return filename_or_url.expanduser().resolve().as_posix() parsed_url = urlparse(filename_or_url) if parsed_url.scheme == "file": # Passed a URL as a file. - return file_uri_to_local_path(filename_or_url).as_uri() + return file_uri_to_local_path(filename_or_url).as_posix() if platform.system() != "Windows": # On Linux, a filesystem path will never present with a scheme. 
diff --git a/src/dve/pipeline/duckdb_pipeline.py b/src/dve/pipeline/duckdb_pipeline.py index 96156a9..87e927d 100644 --- a/src/dve/pipeline/duckdb_pipeline.py +++ b/src/dve/pipeline/duckdb_pipeline.py @@ -1,5 +1,6 @@ """DuckDB implementation for `Pipeline` object.""" +import logging from typing import Optional from duckdb import DuckDBPyConnection, DuckDBPyRelation @@ -21,6 +22,7 @@ class DDBDVEPipeline(BaseDVEPipeline): Modified Pipeline class for running a DVE Pipeline with Spark """ + # pylint: disable=R0913 def __init__( self, processed_files_path: URI, @@ -30,6 +32,7 @@ def __init__( submitted_files_path: Optional[URI], reference_data_loader: Optional[type[BaseRefDataLoader]] = None, job_run_id: Optional[int] = None, + logger: Optional[logging.Logger] = None, ): self._connection = connection super().__init__( @@ -41,6 +44,7 @@ def __init__( submitted_files_path, reference_data_loader, job_run_id, + logger, ) # pylint: disable=arguments-differ diff --git a/src/dve/pipeline/foundry_ddb_pipeline.py b/src/dve/pipeline/foundry_ddb_pipeline.py index 5e0b757..7a597cd 100644 --- a/src/dve/pipeline/foundry_ddb_pipeline.py +++ b/src/dve/pipeline/foundry_ddb_pipeline.py @@ -1,8 +1,11 @@ # pylint: disable=W0223 """A duckdb pipeline for running on Foundry platform""" +import shutil +from pathlib import Path from typing import Optional +from dve.common.error_utils import dump_processing_errors from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( duckdb_get_entity_count, duckdb_write_parquet, @@ -15,7 +18,6 @@ from dve.parser.file_handling.service import _get_implementation from dve.pipeline.duckdb_pipeline import DDBDVEPipeline from dve.pipeline.utils import SubmissionStatus -from dve.reporting.utils import dump_processing_errors @duckdb_get_entity_count @@ -23,6 +25,15 @@ class FoundryDDBPipeline(DDBDVEPipeline): """DuckDB pipeline for running on Foundry Platform""" + def _move_submission_to_processing_files_path(self, submission_info: SubmissionInfo): + """Move submitted file to 'processed_files_path'.""" + _submitted_file_location = Path( + self._submitted_files_path, submission_info.file_name_with_ext # type: ignore + ) + _dest = Path(self.processed_files_path, submission_info.submission_id) + _dest.mkdir(parents=True, exist_ok=True) + shutil.copy2(_submitted_file_location, _dest) + def persist_audit_records(self, submission_info: SubmissionInfo) -> URI: """Write out key audit relations to parquet for persisting to datasets""" write_to = fh.joinuri(self.processed_files_path, submission_info.submission_id, "audit/") @@ -46,8 +57,7 @@ def file_transformation( try: return super().file_transformation(submission_info) except Exception as exc: # pylint: disable=W0718 - self._logger.error(f"File transformation raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("File transformation raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), "file_transformation", @@ -62,11 +72,10 @@ def apply_data_contract( try: return super().apply_data_contract(submission_info, submission_status) except Exception as exc: # pylint: disable=W0718 - self._logger.error(f"Apply data contract raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("Apply data contract raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), - "contract", + "data_contract", [CriticalProcessingError.from_exception(exc)], ) 
self._audit_tables.mark_failed(submissions=[submission_info.submission_id]) @@ -78,8 +87,7 @@ def apply_business_rules( try: return super().apply_business_rules(submission_info, submission_status) except Exception as exc: # pylint: disable=W0718 - self._logger.error(f"Apply business rules raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("Apply business rules raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), "business_rules", @@ -94,10 +102,11 @@ def error_report( try: return super().error_report(submission_info, submission_status) except Exception as exc: # pylint: disable=W0718 - self._logger.error(f"Error reports raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("Error reports raised exception:") sub_stats = None report_uri = None + submission_status = submission_status if submission_status else SubmissionStatus() + submission_status.processing_failed = True dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), "error_report", @@ -113,6 +122,8 @@ def run_pipeline( try: sub_id: str = submission_info.submission_id report_uri = None + if self._submitted_files_path: + self._move_submission_to_processing_files_path(submission_info) self._audit_tables.add_new_submissions(submissions=[submission_info]) self._audit_tables.mark_transform(submission_ids=[sub_id]) sub_info, sub_status = self.file_transformation(submission_info=submission_info) @@ -135,14 +146,15 @@ def run_pipeline( sub_info, sub_status, sub_stats, report_uri = self.error_report( submission_info=submission_info, submission_status=sub_status ) - self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats]) + if sub_stats: + self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats]) except Exception as err: # pylint: disable=W0718 - self._logger.error( - f"During processing of submission_id: {sub_id}, this exception was raised: {err}" + self._logger.exception( + f"During processing of submission_id: {sub_id}, this exception was raised:" ) dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), - "run_pipeline", + "pipeline", [CriticalProcessingError.from_exception(err)], ) self._audit_tables.mark_failed(submissions=[sub_id]) diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index 819656a..c5635b8 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -1,6 +1,7 @@ # pylint: disable=protected-access,too-many-instance-attributes,too-many-arguments,line-too-long """Generic Pipeline object to define how DVE should be interacted with.""" import json +import logging import re from collections import defaultdict from collections.abc import Generator, Iterable, Iterator @@ -15,6 +16,12 @@ from pydantic import validate_arguments import dve.reporting.excel_report as er +from dve.common.error_utils import ( + dump_feedback_errors, + dump_processing_errors, + get_feedback_errors_uri, + load_feedback_messages, +) from dve.core_engine.backends.base.auditing import BaseAuditingManager from dve.core_engine.backends.base.contract import BaseDataContract from dve.core_engine.backends.base.core import EntityManager @@ -28,13 +35,12 @@ from dve.core_engine.loggers import get_logger from dve.core_engine.message import FeedbackMessage from dve.core_engine.models import SubmissionInfo, SubmissionStatisticsRecord -from dve.core_engine.type_hints import URI, FileURI, InfoURI +from 
dve.core_engine.type_hints import URI, DVEStage, FileURI, InfoURI from dve.parser import file_handling as fh from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation from dve.parser.file_handling.service import _get_implementation from dve.pipeline.utils import SubmissionStatus, deadletter_file, load_config, load_reader from dve.reporting.error_report import ERROR_SCHEMA, calculate_aggregates -from dve.reporting.utils import dump_feedback_errors, dump_processing_errors PERMISSIBLE_EXCEPTIONS: tuple[type[Exception]] = ( FileNotFoundError, # type: ignore @@ -57,6 +63,7 @@ def __init__( submitted_files_path: Optional[URI], reference_data_loader: Optional[type[BaseRefDataLoader]] = None, job_run_id: Optional[int] = None, + logger: Optional[logging.Logger] = None, ): self._submitted_files_path = submitted_files_path self._processed_files_path = processed_files_path @@ -66,11 +73,16 @@ def __init__( self._audit_tables = audit_tables self._data_contract = data_contract self._step_implementations = step_implementations - self._logger = get_logger(__name__) + self._logger = logger or get_logger(__name__) self._summary_lock = Lock() self._rec_tracking_lock = Lock() self._aggregates_lock = Lock() + if self._data_contract: + self._data_contract.logger = self._logger + if self._step_implementations: + self._step_implementations.logger = self._logger + @property def job_run_id(self) -> Optional[int]: """Unique Identifier for the job/process that is running this Pipeline.""" @@ -101,7 +113,7 @@ def get_entity_count(entity: EntityType) -> int: """Get a row count of an entity stored as parquet""" raise NotImplementedError() - def get_submission_status(self, step_name: str, submission_id: str) -> SubmissionStatus: + def get_submission_status(self, step_name: DVEStage, submission_id: str) -> SubmissionStatus: """Determine submission status of a submission if not explicitly given""" if not (submission_status := self._audit_tables.get_submission_status(submission_id)): self._logger.warning( @@ -183,6 +195,7 @@ def write_file_to_parquet( errors = [] for model_name, model in models.items(): + self._logger.info(f"Transforming {model_name} to stringified parquet") reader: BaseFileReader = load_reader(dataset, model_name, ext) try: if not entity_type: @@ -223,6 +236,7 @@ def audit_received_file_step( self, pool: ThreadPoolExecutor, submitted_files: Iterable[tuple[FileURI, InfoURI]] ) -> tuple[list[SubmissionInfo], list[SubmissionInfo]]: """Set files as being received and mark them for file transformation""" + self._logger.info("Starting audit received file service") audit_received_futures: list[tuple[str, FileURI, Future]] = [] for submission_file in submitted_files: data_uri, metadata_uri = submission_file @@ -244,8 +258,7 @@ def audit_received_file_step( ) continue except Exception as exc: # pylint: disable=W0703 - self._logger.error(f"audit_received_file raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("audit_received_file raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, submission_id), "audit_received", @@ -285,7 +298,7 @@ def file_transformation( """Transform a file from its original format into a 'stringified' parquet file""" if not self.processed_files_path: raise AttributeError("processed files path not provided") - + self._logger.info(f"Applying file transformation to {submission_info.submission_id}") errors: list[FeedbackMessage] = [] submission_status: SubmissionStatus = SubmissionStatus() submission_file_uri: 
URI = fh.joinuri( @@ -301,8 +314,7 @@ def file_transformation( ) except MessageBearingError as exc: - self._logger.error(f"Unexpected file transformation error: {exc}") - self._logger.exception(exc) + self._logger.exception("Unexpected file transformation error:") errors.extend(exc.messages) if errors: @@ -321,6 +333,7 @@ def file_transformation_step( list[tuple[SubmissionInfo, SubmissionStatus]], list[tuple[SubmissionInfo, SubmissionStatus]] ]: """Step to transform files from their original format into parquet files""" + self._logger.info("Starting file transformation service") file_transform_futures: list[tuple[SubmissionInfo, Future]] = [] for submission_info in submissions_to_process: @@ -352,8 +365,7 @@ def file_transformation_step( ) continue except Exception as exc: # pylint: disable=W0703 - self._logger.error(f"File transformation raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("File transformation raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, sub_info.submission_id), "file_transformation", @@ -393,9 +405,10 @@ def apply_data_contract( self, submission_info: SubmissionInfo, submission_status: Optional[SubmissionStatus] = None ) -> tuple[SubmissionInfo, SubmissionStatus]: """Method for applying the data contract given a submission_info""" + self._logger.info(f"Applying data contract to {submission_info.submission_id}") if not submission_status: submission_status = self.get_submission_status( - "contract", submission_info.submission_id + "data_contract", submission_info.submission_id ) if not self.processed_files_path: raise AttributeError("processed files path not provided") @@ -403,35 +416,36 @@ def apply_data_contract( if not self.rules_path: raise AttributeError("rules path not provided") - read_from = fh.joinuri( - self.processed_files_path, submission_info.submission_id, "transform/" - ) - write_to = fh.joinuri(self.processed_files_path, submission_info.submission_id, "contract/") + working_dir = fh.joinuri(self.processed_files_path, submission_info.submission_id) + + read_from = fh.joinuri(working_dir, "transform/") + write_to = fh.joinuri(working_dir, "data_contract/") + + fh.create_directory(write_to) # simply for local file systems _, config, model_config = load_config(submission_info.dataset_id, self.rules_path) entities = {} + entity_locations = {} for path, _ in fh.iter_prefix(read_from): + entity_locations[fh.get_file_name(path)] = path entities[fh.get_file_name(path)] = self.data_contract.read_parquet(path) - entities, messages, _success = self.data_contract.apply_data_contract( # type: ignore - entities, config.get_contract_metadata() + key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} + + entities, feedback_errors_uri, _success = self.data_contract.apply_data_contract( # type: ignore + working_dir, entities, entity_locations, config.get_contract_metadata(), key_fields ) entitity: self.data_contract.__entity_type__ # type: ignore for entity_name, entitity in entities.items(): self.data_contract.write_parquet(entitity, fh.joinuri(write_to, entity_name)) - key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} - if messages: - dump_feedback_errors( - fh.joinuri(self.processed_files_path, submission_info.submission_id), - "contract", - messages, - key_fields=key_fields, - ) + validation_failed: bool = False + if fh.get_resource_exists(feedback_errors_uri): + messages = load_feedback_messages(feedback_errors_uri) - validation_failed = any(not 
rule_message.is_informational for rule_message in messages) + validation_failed = any(not user_message.is_informational for user_message in messages) if validation_failed: submission_status.validation_failed = True @@ -446,6 +460,7 @@ def data_contract_step( list[tuple[SubmissionInfo, SubmissionStatus]], list[tuple[SubmissionInfo, SubmissionStatus]] ]: """Step to validate the types of an untyped (stringly typed) parquet file""" + self._logger.info("Starting data contract service") processed_files: list[tuple[SubmissionInfo, SubmissionStatus]] = [] failed_processing: list[tuple[SubmissionInfo, SubmissionStatus]] = [] dc_futures: list[tuple[SubmissionInfo, SubmissionStatus, Future]] = [] @@ -454,7 +469,7 @@ def data_contract_step( sub_status = ( sub_status if sub_status - else self.get_submission_status("contract", info.submission_id) + else self.get_submission_status("data_contract", info.submission_id) ) dc_futures.append( ( @@ -478,11 +493,10 @@ def data_contract_step( ) continue except Exception as exc: # pylint: disable=W0703 - self._logger.error(f"Data Contract raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("Data Contract raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, sub_info.submission_id), - "contract", + "data_contract", [CriticalProcessingError.from_exception(exc)], ) sub_status.processing_failed = True @@ -510,10 +524,11 @@ def data_contract_step( def apply_business_rules( self, submission_info: SubmissionInfo, submission_status: Optional[SubmissionStatus] = None - ): + ) -> tuple[SubmissionInfo, SubmissionStatus]: """Apply the business rules to a given submission, the submission may have failed at the data_contract step so this should be passed in as a bool """ + self._logger.info(f"Applying business rules to {submission_info.submission_id}") if not submission_status: submission_status = self.get_submission_status( "business_rules", submission_info.submission_id @@ -532,11 +547,16 @@ def apply_business_rules( raise AttributeError("step implementations has not been provided.") _, config, model_config = load_config(submission_info.dataset_id, self.rules_path) + working_directory: URI = fh.joinuri( + self._processed_files_path, submission_info.submission_id + ) ref_data = config.get_reference_data_config() rules = config.get_rule_metadata() reference_data = self._reference_data_loader(ref_data) # type: ignore entities = {} - contract = fh.joinuri(self.processed_files_path, submission_info.submission_id, "contract") + contract = fh.joinuri( + self.processed_files_path, submission_info.submission_id, "data_contract" + ) for parquet_uri, _ in fh.iter_prefix(contract): file_name = fh.get_file_name(parquet_uri) @@ -553,17 +573,13 @@ def apply_business_rules( entity_manager = EntityManager(entities=entities, reference_data=reference_data) - rule_messages = self.step_implementations.apply_rules(entity_manager, rules) # type: ignore key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} - if rule_messages: - dump_feedback_errors( - fh.joinuri(self.processed_files_path, submission_info.submission_id), - "business_rules", - rule_messages, - key_fields, - ) + self.step_implementations.apply_rules(working_directory, entity_manager, rules, key_fields) # type: ignore + rule_messages = load_feedback_messages( + get_feedback_errors_uri(working_directory, "business_rules") + ) submission_status.validation_failed = ( any(not rule_message.is_informational for rule_message in rule_messages) or 
submission_status.validation_failed @@ -603,6 +619,7 @@ def business_rule_step( list[tuple[SubmissionInfo, SubmissionStatus]], ]: """Step to apply business rules (Step impl) to a typed parquet file""" + self._logger.info("Starting business rules service") future_files: list[tuple[SubmissionInfo, SubmissionStatus, Future]] = [] for submission_info, submission_status in files: @@ -644,8 +661,7 @@ def business_rule_step( ) continue except Exception as exc: # pylint: disable=W0703 - self._logger.error(f"Business Rules raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("Business Rules raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, sub_info.submission_id), "business_rules", @@ -698,15 +714,14 @@ def _get_error_dataframes(self, submission_id: str): errors_dfs = [pl.DataFrame([], schema=ERROR_SCHEMA)] # type: ignore for file, _ in fh.iter_prefix(path): - if fh.get_file_suffix(file) != "json": + if fh.get_file_suffix(file) != "jsonl": continue with fh.open_stream(file) as f: errors = None try: - errors = json.load(f) - except UnicodeDecodeError as exc: - self._logger.error(f"Error reading file: {file}") - self._logger.exception(exc) + errors = [json.loads(err) for err in f.readlines()] + except UnicodeDecodeError: + self._logger.exception(f"Error reading file: {file}") continue if not errors: continue @@ -746,7 +761,7 @@ def error_report( SubmissionInfo, SubmissionStatus, Optional[SubmissionStatisticsRecord], Optional[URI] ]: """Creates the error reports given a submission info and submission status""" - + self._logger.info(f"Generating error report for {submission_info.submission_id}") if not submission_status: submission_status = self.get_submission_status( "error_report", submission_info.submission_id @@ -755,6 +770,7 @@ def error_report( if not self.processed_files_path: raise AttributeError("processed files path not provided") + self._logger.info("Reading error dataframes") errors_df, aggregates = self._get_error_dataframes(submission_info.submission_id) if not submission_status.number_of_records: @@ -793,9 +809,11 @@ def error_report( "error_reports", f"{submission_info.file_name}_{submission_info.file_extension.strip('.')}.xlsx", ) + self._logger.info("Writing error report") with fh.open_stream(report_uri, "wb") as stream: stream.write(er.ExcelFormat.convert_to_bytes(workbook)) + self._logger.info("Publishing error aggregates") self._publish_error_aggregates(submission_info.submission_id, aggregates) return submission_info, submission_status, sub_stats, report_uri @@ -811,6 +829,7 @@ def error_report_step( """Step to produce error reports takes processed files and files that failed file transformation """ + self._logger.info("Starting error reports service") futures: list[tuple[SubmissionInfo, SubmissionStatus, Future]] = [] reports: list[ tuple[SubmissionInfo, SubmissionStatus, Union[None, SubmissionStatisticsRecord], URI] @@ -845,8 +864,7 @@ def error_report_step( ) continue except Exception as exc: # pylint: disable=W0703 - self._logger.error(f"Error reports raised exception: {exc}") - self._logger.exception(exc) + self._logger.exception("Error reports raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, sub_info.submission_id), "error_report", diff --git a/src/dve/pipeline/spark_pipeline.py b/src/dve/pipeline/spark_pipeline.py index 4111cf3..71fdb32 100644 --- a/src/dve/pipeline/spark_pipeline.py +++ b/src/dve/pipeline/spark_pipeline.py @@ -1,5 +1,6 @@ """Spark implementation for `Pipeline` 
object.""" +import logging from concurrent.futures import Executor from typing import Optional @@ -23,6 +24,7 @@ class SparkDVEPipeline(BaseDVEPipeline): Polymorphed Pipeline class for running a DVE Pipeline with Spark """ + # pylint: disable=R0913 def __init__( self, processed_files_path: URI, @@ -32,6 +34,7 @@ def __init__( reference_data_loader: Optional[type[BaseRefDataLoader]] = None, spark: Optional[SparkSession] = None, job_run_id: Optional[int] = None, + logger: Optional[logging.Logger] = None, ): self._spark = spark if spark else SparkSession.builder.getOrCreate() super().__init__( @@ -43,6 +46,7 @@ def __init__( submitted_files_path, reference_data_loader, job_run_id, + logger, ) # pylint: disable=arguments-differ diff --git a/src/dve/reporting/error_report.py b/src/dve/reporting/error_report.py index ba4a4ac..8852fcb 100644 --- a/src/dve/reporting/error_report.py +++ b/src/dve/reporting/error_report.py @@ -1,15 +1,14 @@ """Error report generation""" -import datetime as dt import json from collections import deque from functools import partial from multiprocessing import Pool, cpu_count -from typing import Union import polars as pl from polars import DataFrame, LazyFrame, Utf8, col # type: ignore +from dve.common.error_utils import conditional_cast from dve.core_engine.message import FeedbackMessage from dve.parser.file_handling.service import open_stream @@ -49,22 +48,6 @@ def get_error_codes(error_code_path: str) -> LazyFrame: return pl.DataFrame(df_lists).lazy() # type: ignore -def conditional_cast(value, primary_keys: list[str], value_separator: str) -> Union[list[str], str]: - """Determines what to do with a value coming back from the error list""" - if isinstance(value, list): - casts = [ - conditional_cast(val, primary_keys, value_separator) for val in value - ] # type: ignore - return value_separator.join( - [f"{pk}: {id}" if pk else "" for pk, id in zip(primary_keys, casts)] - ) - if isinstance(value, dt.date): - return value.isoformat() - if isinstance(value, dict): - return "" - return str(value) - - def _convert_inner_dict(error: FeedbackMessage, key_fields): return { key: ( diff --git a/src/dve/reporting/utils.py b/src/dve/reporting/utils.py deleted file mode 100644 index 8832b6a..0000000 --- a/src/dve/reporting/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Utilities to support reporting""" - -import json -from typing import Optional - -import dve.parser.file_handling as fh -from dve.core_engine.exceptions import CriticalProcessingError -from dve.core_engine.type_hints import URI, Messages -from dve.reporting.error_report import conditional_cast - - -def dump_feedback_errors( - working_folder: URI, - step_name: str, - messages: Messages, - key_fields: Optional[dict[str, list[str]]] = None, -): - """Write out captured feedback error messages.""" - if not working_folder: - raise AttributeError("processed files path not passed") - - if not key_fields: - key_fields = {} - - errors = fh.joinuri(working_folder, "errors", f"{step_name}_errors.json") - processed = [] - - for message in messages: - if message.original_entity is not None: - primary_keys = key_fields.get(message.original_entity, []) - elif message.entity is not None: - primary_keys = key_fields.get(message.entity, []) - else: - primary_keys = [] - - error = message.to_dict( - key_field=primary_keys, - value_separator=" -- ", - max_number_of_values=10, - record_converter=None, - ) - error["Key"] = conditional_cast(error["Key"], primary_keys, value_separator=" -- ") - processed.append(error) - - with 
fh.open_stream(errors, "a") as f: - json.dump( - processed, - f, - default=str, - ) - - -def dump_processing_errors( - working_folder: URI, step_name: str, errors: list[CriticalProcessingError] -): - """Write out critical processing errors""" - if not working_folder: - raise AttributeError("processed files path not passed") - if not step_name: - raise AttributeError("step name not passed") - if not errors: - raise AttributeError("errors list not passed") - - error_file: URI = fh.joinuri(working_folder, "errors", "processing_errors.json") - processed = [] - - for error in errors: - processed.append( - { - "step_name": step_name, - "error_location": "processing", - "error_level": "integrity", - "error_message": error.error_message, - } - ) - - with fh.open_stream(error_file, "a") as f: - json.dump( - processed, - f, - default=str, - ) diff --git a/tests/features/steps/utilities.py b/tests/features/steps/utilities.py index 30ebe16..aa9adc1 100644 --- a/tests/features/steps/utilities.py +++ b/tests/features/steps/utilities.py @@ -28,7 +28,6 @@ ] SERVICE_TO_STORAGE_PATH_MAPPING: Dict[str, str] = { "file_transformation": "transform", - "data_contract": "contract", } @@ -49,12 +48,12 @@ def load_errors_from_service(processing_folder: Path, service: str) -> pl.DataFr err_location = Path( processing_folder, "errors", - f"{SERVICE_TO_STORAGE_PATH_MAPPING.get(service, service)}_errors.json", + f"{SERVICE_TO_STORAGE_PATH_MAPPING.get(service, service)}_errors.jsonl", ) msgs = [] try: with open(err_location) as errs: - msgs = json.load(errs) + msgs = [json.loads(err) for err in errs.readlines()] except FileNotFoundError: pass diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py index 61920c2..442e4a0 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py @@ -14,8 +14,12 @@ from dve.core_engine.backends.implementations.duckdb.readers.xml import DuckDBXMLStreamReader from dve.core_engine.backends.metadata.contract import DataContractMetadata, ReaderConfig from dve.core_engine.backends.utilities import stringify_model +from dve.core_engine.message import UserMessage from dve.core_engine.type_hints import URI from dve.core_engine.validation import RowValidator +from dve.parser.file_handling import get_resource_exists, joinuri +from dve.parser.file_handling.service import get_parent +from dve.common.error_utils import load_feedback_messages from tests.test_core_engine.test_backends.fixtures import ( nested_all_string_parquet, simple_all_string_parquet, @@ -83,15 +87,16 @@ def test_duckdb_data_contract_csv(temp_csv_file): header=True, delim=",", connection=connection ).read_to_entity_type(DuckDBPyRelation, str(uri), "test_ds", stringify_model(mdl)) } + entity_locations: Dict[str, URI] = {"test_ds": str(uri)} data_contract: DuckDBDataContract = DuckDBDataContract(connection) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta) rel: DuckDBPyRelation = entities.get("test_ds") assert dict(zip(rel.columns, rel.dtypes)) == { fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) for fld in mdl.__fields__.values() } - assert 
len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert stage_successful @@ -150,6 +155,11 @@ def test_duckdb_data_contract_xml(temp_xml_file): ddb_connection=connection, root_tag="root", record_tag="ClassData" ).read_to_relation(str(uri), "class_info", class_model), } + entity_locations: dict[str, URI] = {} + for entity, rel in entities.items(): + loc: URI = joinuri(get_parent(uri.as_posix()), f"{entity}.parquet") + rel.write_parquet(loc, compression="snappy") + entity_locations[entity] = loc dc_meta = DataContractMetadata( reader_metadata={ @@ -178,7 +188,7 @@ def test_duckdb_data_contract_xml(temp_xml_file): ) data_contract: DuckDBDataContract = DuckDBDataContract(connection) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta) header_rel: DuckDBPyRelation = entities.get("test_header") header_expected_schema: Dict[str, DuckDBPyType] = { fld.name: get_duckdb_type_from_annotation(fld.type_) @@ -189,7 +199,7 @@ def test_duckdb_data_contract_xml(temp_xml_file): for fld in class_model.__fields__.values() } class_data_rel: DuckDBPyRelation = entities.get("test_class_info") - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert header_rel.count("*").fetchone()[0] == 1 assert dict(zip(header_rel.columns, header_rel.dtypes)) == header_expected_schema assert class_data_rel.count("*").fetchone()[0] == 2 @@ -237,9 +247,9 @@ def test_ddb_data_contract_read_and_write_basic_parquet( reporting_fields={"simple_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"simple_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["simple_model"].count("*").fetchone()[0] == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("simple_model_output.parquet") @@ -296,9 +306,9 @@ def test_ddb_data_contract_read_nested_parquet(nested_all_string_parquet): reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["nested_model"].count("*").fetchone()[0] == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("nested_model_output.parquet") @@ -353,12 +363,13 @@ def test_duckdb_data_contract_custom_error_details(nested_all_string_parquet_w_e reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful + messages: list[UserMessage] = [msg for msg in load_feedback_messages(feedback_errors_uri)] assert len(messages) == 2 - messages = sorted(messages, key= lambda x: x.error_code) - assert messages[0].error_code == 
"SUBFIELDTESTIDBAD" - assert messages[0].error_message == "subfield id is invalid: subfield.id - WRONG" - assert messages[1].error_code == "TESTIDBAD" - assert messages[1].error_message == "id is invalid: id - WRONG" - assert messages[1].entity == "test_rename" \ No newline at end of file + messages = sorted(messages, key= lambda x: x.ErrorCode) + assert messages[0].ErrorCode == "SUBFIELDTESTIDBAD" + assert messages[0].ErrorMessage == "subfield id is invalid: subfield.id - WRONG" + assert messages[1].ErrorCode == "TESTIDBAD" + assert messages[1].ErrorMessage == "id is invalid: id - WRONG" + assert messages[1].Entity == "test_rename" \ No newline at end of file diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py index 8490ab5..2899dc6 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py @@ -1,4 +1,3 @@ -from typing import Dict, List import pytest from dve.core_engine.backends.implementations.duckdb.utilities import ( @@ -16,7 +15,7 @@ ), ], ) -def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str]): +def test_expr_mapping_to_columns(expressions: dict[str, str], expected: list[str]): observed = expr_mapping_to_columns(expressions) assert observed == expected @@ -51,6 +50,7 @@ def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str ), ], ) -def test_expr_array_to_columns(expressions: Dict[str, str], expected: list[str]): +def test_expr_array_to_columns(expressions: dict[str, str], expected: list[str]): observed = expr_array_to_columns(expressions) assert observed == expected + diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py index 789ca1a..921c9be 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py @@ -16,8 +16,11 @@ from dve.core_engine.backends.implementations.spark.contract import SparkDataContract from dve.core_engine.backends.metadata.contract import DataContractMetadata, ReaderConfig +from dve.core_engine.message import UserMessage from dve.core_engine.type_hints import URI from dve.core_engine.validation import RowValidator +from dve.parser.file_handling.service import get_parent, get_resource_exists +from dve.common.error_utils import load_feedback_messages from tests.test_core_engine.test_backends.fixtures import ( nested_all_string_parquet, nested_all_string_parquet_w_errors, @@ -67,9 +70,9 @@ def test_spark_data_contract_read_and_write_basic_parquet( reporting_fields={"simple_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"simple_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["simple_model"].count() == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("simple_model_output.parquet") @@ -140,9 +143,9 @@ def 
test_spark_data_contract_read_nested_parquet(nested_all_string_parquet): reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["nested_model"].count() == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("nested_model_output.parquet") @@ -227,14 +230,15 @@ def test_spark_data_contract_custom_error_details(nested_all_string_parquet_w_er reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful + messages: list[UserMessage] = [msg for msg in load_feedback_messages(feedback_errors_uri)] assert len(messages) == 2 - messages = sorted(messages, key= lambda x: x.error_code) - assert messages[0].error_code == "SUBFIELDTESTIDBAD" - assert messages[0].error_message == "subfield id is invalid: subfield.id - WRONG" - assert messages[1].error_code == "TESTIDBAD" - assert messages[1].error_message == "id is invalid: id - WRONG" - assert messages[1].entity == "test_rename" + messages = sorted(messages, key= lambda x: x.ErrorCode) + assert messages[0].ErrorCode == "SUBFIELDTESTIDBAD" + assert messages[0].ErrorMessage == "subfield id is invalid: subfield.id - WRONG" + assert messages[1].ErrorCode == "TESTIDBAD" + assert messages[1].ErrorMessage == "id is invalid: id - WRONG" + assert messages[1].Entity == "test_rename" \ No newline at end of file diff --git a/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py index 900632d..c326fef 100644 --- a/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py +++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py @@ -57,7 +57,7 @@ def test_ddb_json_reader_all_str(temp_json_file): expected_fields = [fld for fld in mdl.__fields__] reader = DuckDBJSONReader() rel: DuckDBPyRelation = reader.read_to_entity_type( - DuckDBPyRelation, uri, "test", stringify_model(mdl) + DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl) ) assert rel.columns == expected_fields assert dict(zip(rel.columns, rel.dtypes)) == {fld: "VARCHAR" for fld in expected_fields} @@ -68,7 +68,7 @@ def test_ddb_json_reader_cast(temp_json_file): uri, data, mdl = temp_json_file expected_fields = [fld for fld in mdl.__fields__] reader = DuckDBJSONReader() - rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri, "test", mdl) + rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri.as_posix(), "test", mdl) assert rel.columns == expected_fields assert dict(zip(rel.columns, rel.dtypes)) == { @@ -82,7 +82,7 @@ def test_ddb_csv_write_parquet(temp_json_file): uri, _, mdl = temp_json_file reader = DuckDBJSONReader() rel: DuckDBPyRelation = reader.read_to_entity_type( - DuckDBPyRelation, uri, "test", stringify_model(mdl) + DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl) ) target_loc: Path = uri.parent.joinpath("test_parquet.parquet").as_posix() reader.write_parquet(rel, target_loc) 
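The new test module added below pins down `check_csv_header_expected` from `dve.core_engine.backends.readers.utilities`: given a file path, a pydantic model and a delimiter, it returns the set of model fields that cannot be found in the file's header row, with an empty set meaning the header matches regardless of column order. The helper's body is not part of this diff, so the following is only a hedged sketch of that contract, written to agree with the parametrised cases below; the real implementation may differ in detail.

# Hedged sketch only; name carries a _sketch suffix to mark it as hypothetical.
import csv

from pydantic import BaseModel


def check_csv_header_expected_sketch(path: str, model: type[BaseModel], delim: str) -> set[str]:
    """Return the model fields that are absent from the file's header row."""
    with open(path, newline="", encoding="utf-8") as handle:
        # Only the first row matters for the header sense check.
        header = next(csv.reader(handle, delimiter=delim), [])
    found = {column.strip().strip('"') for column in header}
    return {name for name in model.__fields__ if name not in found}

In particular, a pipe-separated header checked with delim="," parses as a single unmatched column, so every expected field is reported missing, which would let a wrongly delimited submission be flagged before any rows are read.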
diff --git a/tests/test_core_engine/test_backends/test_readers/test_utilities.py b/tests/test_core_engine/test_backends/test_readers/test_utilities.py new file mode 100644 index 0000000..4426769 --- /dev/null +++ b/tests/test_core_engine/test_backends/test_readers/test_utilities.py @@ -0,0 +1,55 @@ +import datetime as dt +from pathlib import Path +import tempfile +from uuid import uuid4 + +import pytest +from pydantic import BaseModel, create_model + +from dve.core_engine.backends.readers.utilities import check_csv_header_expected + +@pytest.mark.parametrize( + ["header_row", "delim", "schema", "expected"], + [ + ( + "field1,field2,field3", + ",", + {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)}, + set(), + ), + ( + "field2,field3,field1", + ",", + {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)}, + set(), + ), + ( + "str_field|int_field|date_field|", + ",", + {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())}, + {"str_field","int_field","date_field"}, + ), + ( + '"str_field"|"int_field"|"date_field"', + "|", + {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())}, + set(), + ), + ( + 'str_field,int_field,date_field\n', + ",", + {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())}, + set(), + ), + + ], +) +def test_check_csv_header_expected( + header_row: str, delim: str, schema: type[BaseModel], expected: set[str] +): + mdl = create_model("TestModel", **schema) + with tempfile.TemporaryDirectory() as tmpdir: + fle = Path(tmpdir).joinpath(f"test_file_{uuid4().hex}.csv") + fle.open("w+").write(header_row) + res = check_csv_header_expected(fle.as_posix(), mdl, delim) + assert res == expected \ No newline at end of file diff --git a/tests/test_core_engine/test_engine.py b/tests/test_core_engine/test_engine.py index 5e16f09..5118cbd 100644 --- a/tests/test_core_engine/test_engine.py +++ b/tests/test_core_engine/test_engine.py @@ -9,6 +9,7 @@ import pytest from pyspark.sql import SparkSession +from dve.common.error_utils import load_all_error_messages from dve.core_engine.backends.implementations.spark.backend import SparkBackend from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.engine import CoreEngine @@ -25,21 +26,21 @@ def test_dummy_planet_run(self, spark: SparkSession, temp_dir: str): with warnings.catch_warnings(): warnings.simplefilter("ignore") test_instance = CoreEngine.build( - dataset_config_path=config_path.as_uri(), + dataset_config_path=config_path.as_posix(), output_prefix=Path(temp_dir), - backend=SparkBackend(dataset_config_uri=config_path.parent.as_uri(), + backend=SparkBackend(dataset_config_uri=config_path.parent.as_posix(), spark_session=spark, reference_data_loader=refdata_loader) ) with test_instance: - _, messages = test_instance.run_pipeline( + _, errors_uri = test_instance.run_pipeline( entity_locations={ - "planets": get_test_file_path("planets/planets_demo.csv").as_uri(), + "planets": get_test_file_path("planets/planets_demo.csv").as_posix(), }, ) - critical_messages = [message for message in messages if message.is_critical] + critical_messages = [message for message in load_all_error_messages(errors_uri) if message.is_critical] assert not critical_messages output_files = Path(temp_dir).iterdir() @@ -55,7 +56,7 @@ def test_dummy_planet_run(self, spark: SparkSession, temp_dir: str): def test_dummy_demographics_run(self, spark, temp_dir: str): 
"""Test that we can still run the test example with the dummy demographics data.""" - config_path = get_test_file_path("demographics/basic_demographics.dischema.json").as_uri() + config_path = get_test_file_path("demographics/basic_demographics.dischema.json").as_posix() with warnings.catch_warnings(): warnings.simplefilter("ignore") test_instance = CoreEngine.build( @@ -64,15 +65,15 @@ def test_dummy_demographics_run(self, spark, temp_dir: str): ) with test_instance: - _, messages = test_instance.run_pipeline( + _, errors_uri = test_instance.run_pipeline( entity_locations={ "demographics": get_test_file_path( "demographics/basic_demographics.csv" - ).as_uri(), + ).as_posix(), }, ) - critical_messages = [message for message in messages if message.is_critical] + critical_messages = [message for message in load_all_error_messages(errors_uri) if message.is_critical] assert not critical_messages output_files = Path(temp_dir).iterdir() @@ -88,7 +89,7 @@ def test_dummy_demographics_run(self, spark, temp_dir: str): def test_dummy_books_run(self, spark, temp_dir: str): """Test that we can handle files with more complex nested schemas.""" - config_path = get_test_file_path("books/nested_books.dischema.json").as_uri() + config_path = get_test_file_path("books/nested_books.dischema.json").as_posix() with warnings.catch_warnings(): warnings.simplefilter("ignore") test_instance = CoreEngine.build( @@ -96,14 +97,14 @@ def test_dummy_books_run(self, spark, temp_dir: str): output_prefix=Path(temp_dir), ) with test_instance: - _, messages = test_instance.run_pipeline( + _, errors_uri = test_instance.run_pipeline( entity_locations={ - "header": get_test_file_path("books/nested_books.xml").as_uri(), - "nested_books": get_test_file_path("books/nested_books.xml").as_uri(), + "header": get_test_file_path("books/nested_books.xml").as_posix(), + "nested_books": get_test_file_path("books/nested_books.xml").as_posix(), } ) - critical_messages = [message for message in messages if message.is_critical] + critical_messages = [message for message in load_all_error_messages(errors_uri) if message.is_critical] assert not critical_messages output_files = Path(temp_dir).iterdir() diff --git a/tests/test_parser/test_file_handling.py b/tests/test_parser/test_file_handling.py index cfa90be..8833d17 100644 --- a/tests/test_parser/test_file_handling.py +++ b/tests/test_parser/test_file_handling.py @@ -32,6 +32,7 @@ resolve_location, ) from dve.parser.file_handling.implementations import S3FilesystemImplementation +from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation, file_uri_to_local_path from dve.parser.file_handling.service import _get_implementation from dve.parser.type_hints import Hostname, Scheme, URIPath @@ -192,11 +193,14 @@ def test_iter_prefix(self, prefix: str): file.write("") actual_nodes = sorted(iter_prefix(prefix, recursive=False)) + cleaned_prefix = ( + file_uri_to_local_path(prefix).as_posix() if type(_get_implementation(prefix)) == LocalFilesystemImplementation + else prefix) expected_nodes = sorted( [ - (prefix + "/test_file.txt", "resource"), - (prefix + "/test_sibling.txt", "resource"), - (prefix + "/test_prefix/", "directory"), + (cleaned_prefix + "/test_file.txt", "resource"), + (cleaned_prefix + "/test_sibling.txt", "resource"), + (cleaned_prefix + "/test_prefix/", "directory"), ] ) assert actual_nodes == expected_nodes @@ -221,14 +225,20 @@ def test_iter_prefix_recursive(self, prefix: str): "/test_prefix/another_level/test_file.txt", 
"/test_prefix/another_level/nested_sibling.txt", ] - resource_uris = [prefix + uri_path for uri_path in structure] - directory_uris = [prefix + "/test_prefix/", prefix + "/test_prefix/another_level/"] + cleaned_prefix = ( + file_uri_to_local_path(prefix).as_posix() + if type(_get_implementation(prefix)) == LocalFilesystemImplementation + else prefix + ) + + resource_uris = [cleaned_prefix + uri_path for uri_path in structure] + directory_uris = [cleaned_prefix + "/test_prefix/", cleaned_prefix + "/test_prefix/another_level/"] for uri in resource_uris: with open_stream(uri, "w") as file: file.write("") - actual_nodes = sorted(iter_prefix(prefix, recursive=True)) + actual_nodes = sorted(iter_prefix(cleaned_prefix, recursive=True)) expected_nodes = sorted( [(uri, "directory") for uri in directory_uris] + [(uri, "resource") for uri in resource_uris] @@ -409,21 +419,21 @@ def test_cursed_s3_keys_supported(temp_s3_prefix: str): [ ( Path("abc/samples/planet_test_records.xml"), - Path("abc/samples/planet_test_records.xml").resolve().as_uri(), + Path("abc/samples/planet_test_records.xml").resolve().as_posix(), ), - ("file:///home/user/file.txt", Path("/home/user/file.txt").as_uri()), + ("file:///home/user/file.txt", Path("/home/user/file.txt").as_posix()), ("s3://bucket/path/within/bucket/file.csv", "s3://bucket/path/within/bucket/file.csv"), ( "file:///abc/samples/planet_test_records.xml", - "file:///abc/samples/planet_test_records.xml", + "/abc/samples/planet_test_records.xml", ), ( "/abc/samples/planet_test_records.xml", - Path("/abc/samples/planet_test_records.xml").as_uri(), + Path("/abc/samples/planet_test_records.xml").as_posix(), ), ( Path("/abc/samples/planet_test_records.xml"), - Path("/abc/samples/planet_test_records.xml").as_uri(), + Path("/abc/samples/planet_test_records.xml").as_posix(), ), ], # fmt: on diff --git a/tests/test_pipeline/pipeline_helpers.py b/tests/test_pipeline/pipeline_helpers.py index 1518ccf..ddd4ef8 100644 --- a/tests/test_pipeline/pipeline_helpers.py +++ b/tests/test_pipeline/pipeline_helpers.py @@ -172,7 +172,7 @@ def planets_data_after_data_contract() -> Iterator[Tuple[SubmissionInfo, str]]: dataset_id="planets", file_extension="json", ) - output_path = Path(tdir, submitted_file_info.submission_id, "contract", "planets") + output_path = Path(tdir, submitted_file_info.submission_id, "data_contract", "planets") output_path.mkdir(parents=True) planet_contract_data = { @@ -220,7 +220,7 @@ def planets_data_after_data_contract_that_break_business_rules() -> Iterator[ dataset_id="planets", file_extension="json", ) - output_path = Path(tdir, submitted_file_info.submission_id, "contract", "planets") + output_path = Path(tdir, submitted_file_info.submission_id, "data_contract", "planets") output_path.mkdir(parents=True) planet_contract_data = { @@ -398,9 +398,10 @@ def error_data_after_business_rules() -> Iterator[Tuple[SubmissionInfo, str]]: } ]""" ) - output_file_path = output_path / "business_rules_errors.json" + output_file_path = output_path / "business_rules_errors.jsonl" with open(output_file_path, "w", encoding="utf-8") as f: - json.dump(error_data, f) + for entry in error_data: + f.write(json.dumps(entry) + "\n") yield submitted_file_info, tdir diff --git a/tests/test_pipeline/test_duckdb_pipeline.py b/tests/test_pipeline/test_duckdb_pipeline.py index 58eb4ac..29e0734 100644 --- a/tests/test_pipeline/test_duckdb_pipeline.py +++ b/tests/test_pipeline/test_duckdb_pipeline.py @@ -132,7 +132,7 @@ def test_data_contract_step( assert len(success) == 1 assert not 
     assert len(failed) == 0

-    assert Path(processed_file_path, sub_info.submission_id, "contract", "planets").exists()
+    assert Path(processed_file_path, sub_info.submission_id, "data_contract", "planets").exists()

     assert pl_row_count(audit_manager.get_all_business_rule_submissions().pl()) == 1
diff --git a/tests/test_pipeline/test_foundry_ddb_pipeline.py b/tests/test_pipeline/test_foundry_ddb_pipeline.py
index 49440fc..68fec99 100644
--- a/tests/test_pipeline/test_foundry_ddb_pipeline.py
+++ b/tests/test_pipeline/test_foundry_ddb_pipeline.py
@@ -5,9 +5,10 @@
 from datetime import datetime
 from pathlib import Path
 import shutil
+import tempfile
 from uuid import uuid4

-import pytest
+import polars as pl

 from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager
 from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader
@@ -115,4 +116,79 @@ def test_foundry_runner_error(planet_test_files, temp_ddb_conn):
     output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
     assert not fh.get_resource_exists(report_uri)
     assert not output_loc
+
+    perror_path = Path(
+        processing_folder,
+        sub_info.submission_id,
+        "processing_errors",
+        "processing_errors.json"
+    )
+    assert perror_path.exists()
+    perror_schema = {
+        "step_name": pl.Utf8(),
+        "error_location": pl.Utf8(),
+        "error_level": pl.Utf8(),
+        "error_message": pl.Utf8(),
+        "error_traceback": pl.List(pl.Utf8()),
+    }
+    expected_error_df = (
+        pl.DataFrame(
+            [
+                {
+                    "step_name": "file_transformation",
+                    "error_location": "processing",
+                    "error_level": "integrity",
+                    "error_message": "ReaderLacksEntityTypeSupport()",
+                    "error_traceback": None,
+                },
+            ],
+            perror_schema
+        )
+        .select(pl.col("step_name"), pl.col("error_location"), pl.col("error_message"))
+    )
+    actual_error_df = (
+        pl.read_json(perror_path, schema=perror_schema)
+        .select(pl.col("step_name"), pl.col("error_location"), pl.col("error_message"))
+    )
+    assert actual_error_df.equals(expected_error_df)
+    assert len(list(fh.iter_prefix(audit_files))) == 2
+
+
+def test_foundry_runner_with_submitted_files_path(movies_test_files, temp_ddb_conn):
+    db_file, conn = temp_ddb_conn
+    ref_db_file = Path(db_file.parent, "movies_refdata.duckdb").as_posix()
+    conn.sql(f"ATTACH '{ref_db_file}' AS movies_refdata")
+    conn.read_parquet(
+        get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()
+    ).to_table("movies_refdata.sequels")
+    processing_folder = Path(tempfile.mkdtemp()).as_posix()
+    submitted_files_path = Path(movies_test_files).as_posix()
+    sub_id = uuid4().hex
+    sub_info = SubmissionInfo(
+        submission_id=sub_id,
+        dataset_id="movies",
+        file_name="good_movies",
+        file_extension="json",
+        submitting_org="TEST",
+        datetime_received=datetime(2025,11,5)
+    )
+
+    DuckDBRefDataLoader.connection = conn
+    DuckDBRefDataLoader.dataset_config_uri = None
+
+    with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager:
+        dve_pipeline = FoundryDDBPipeline(
+            processed_files_path=processing_folder,
+            audit_tables=audit_manager,
+            connection=conn,
+            rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(),
+            submitted_files_path=submitted_files_path,
+            reference_data_loader=DuckDBRefDataLoader,
+        )
+        output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
+
+    assert Path(processing_folder, sub_id, sub_info.file_name_with_ext).exists()
+    assert fh.get_resource_exists(report_uri)
+    assert len(list(fh.iter_prefix(output_loc))) == 2
+    assert len(list(fh.iter_prefix(audit_files))) == 3
diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py
index 7f4738f..910626a 100644
--- a/tests/test_pipeline/test_spark_pipeline.py
+++ b/tests/test_pipeline/test_spark_pipeline.py
@@ -16,10 +16,12 @@
 import polars as pl
 from pyspark.sql import SparkSession

+from dve.common.error_utils import load_feedback_messages
 from dve.core_engine.backends.base.auditing import FilterCriteria
 from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager
 from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
 from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations
+from dve.core_engine.message import UserMessage
 from dve.core_engine.models import ProcessingStatusRecord, SubmissionInfo, SubmissionStatisticsRecord
 import dve.parser.file_handling as fh
 from dve.pipeline.spark_pipeline import SparkDVEPipeline
@@ -135,7 +137,7 @@ def test_apply_data_contract_success(

     assert not sub_status.validation_failed

-    assert Path(Path(processed_file_path), sub_info.submission_id, "contract", "planets").exists()
+    assert Path(Path(processed_file_path), sub_info.submission_id, "data_contract", "planets").exists()


 def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
@@ -157,9 +159,9 @@ def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
     assert sub_status.validation_failed

     output_path = Path(processed_file_path) / sub_info.submission_id
-    assert Path(output_path, "contract", "planets").exists()
+    assert Path(output_path, "data_contract", "planets").exists()

-    errors_path = Path(output_path, "errors", "contract_errors.json")
+    errors_path = Path(output_path, "errors", "data_contract_errors.jsonl")
     assert errors_path.exists()

     expected_errors = [
@@ -203,10 +205,10 @@ def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
             "Category": "Bad value",
         },
     ]
-    with open(errors_path, "r", encoding="utf-8") as f:
-        actual_errors = json.load(f)
+
+    actual_errors = list(load_feedback_messages(errors_path.as_posix()))

-    assert actual_errors == expected_errors
+    assert actual_errors == [UserMessage(**err) for err in expected_errors]


 def test_data_contract_step(
@@ -234,7 +236,7 @@ def test_data_contract_step(
     assert not success[0][1].validation_failed
     assert len(failed) == 0

-    assert Path(processed_file_path, sub_info.submission_id, "contract", "planets").exists()
+    assert Path(processed_file_path, sub_info.submission_id, "data_contract", "planets").exists()

     assert audit_manager.get_all_business_rule_submissions().count() == 1
     audit_result = audit_manager.get_all_error_report_submissions()
@@ -329,7 +331,7 @@ def test_apply_business_rules_with_data_errors(  # pylint: disable=redefined-out
     assert og_planets_entity_path.exists()
     assert spark.read.parquet(str(og_planets_entity_path)).count() == 1

-    errors_path = Path(br_path.parent, "errors", "business_rules_errors.json")
+    errors_path = Path(br_path.parent, "errors", "business_rules_errors.jsonl")
     assert errors_path.exists()

     expected_errors = [
@@ -360,10 +362,10 @@ def test_apply_business_rules_with_data_errors(  # pylint: disable=redefined-out
             "Category": "Bad value",
         },
     ]
-    with open(errors_path, "r", encoding="utf-8") as f:
-        actual_errors = json.load(f)
+
+    actual_errors = list(load_feedback_messages(errors_path.as_posix()))

-    assert actual_errors == expected_errors
+    assert actual_errors == [UserMessage(**err) for err in expected_errors]


 def test_business_rule_step(
diff --git a/tests/test_reporting/__init__.py b/tests/test_reporting/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_reporting/test_error_utils.py b/tests/test_reporting/test_error_utils.py
new file mode 100644
index 0000000..c5f7045
--- /dev/null
+++ b/tests/test_reporting/test_error_utils.py
@@ -0,0 +1,51 @@
+"""Test utility functions & objects in the dve.common.error_utils module"""
+
+import tempfile
+from pathlib import Path
+
+import polars as pl
+
+from dve.core_engine.exceptions import CriticalProcessingError
+from dve.common.error_utils import dump_processing_errors
+
+# pylint: disable=C0116
+
+
+def test_dump_processing_errors():
+    perror_schema = {
+        "step_name": pl.Utf8(),
+        "error_location": pl.Utf8(),
+        "error_level": pl.Utf8(),
+        "error_message": pl.Utf8(),
+        "error_traceback": pl.List(pl.Utf8()),
+    }
+    with tempfile.TemporaryDirectory() as temp_dir:
+        dump_processing_errors(
+            temp_dir,
+            "test_step",
+            [CriticalProcessingError("test error message")]
+        )
+
+        output_path = Path(temp_dir, "processing_errors")
+
+        assert output_path.exists()
+        assert len(list(output_path.iterdir())) == 1
+
+        expected_df = pl.DataFrame(
+            [
+                {
+                    "step_name": "test_step",
+                    "error_location": "processing",
+                    "error_level": "integrity",
+                    "error_message": "test error message",
+                    "error_traceback": None,
+                },
+            ],
+            perror_schema
+        )
+        error_df = pl.read_json(
+            Path(output_path, "processing_errors.json")
+        )
+        cols_to_check = ["step_name", "error_location", "error_level", "error_message"]
+
+        assert error_df.select(pl.col(k) for k in cols_to_check).equals(expected_df.select(pl.col(k) for k in cols_to_check))
diff --git a/tests/test_error_reporting/test_excel_report.py b/tests/test_reporting/test_excel_report.py
similarity index 100%
rename from tests/test_error_reporting/test_excel_report.py
rename to tests/test_reporting/test_excel_report.py