diff --git a/docs/README.md b/docs/README.md
index 7ab5d92..fc0de4a 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -165,8 +165,8 @@ for entity in data_contract_config.schemas:
 
 # Data contract step here
 data_contract = SparkDataContract(spark_session=spark)
-entities, validation_messages, success = data_contract.apply_data_contract(
-    entities, data_contract_config
+entities, feedback_errors_uri, success = data_contract.apply_data_contract(
+    entities, None, data_contract_config
 )
 ```
 
diff --git a/src/dve/common/__init__.py b/src/dve/common/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/dve/common/error_utils.py b/src/dve/common/error_utils.py
new file mode 100644
index 0000000..120c902
--- /dev/null
+++ b/src/dve/common/error_utils.py
@@ -0,0 +1,189 @@
+"""Utilities to support reporting"""
+
+import datetime as dt
+import json
+import logging
+from collections.abc import Iterable
+from itertools import chain
+from multiprocessing import Queue
+from threading import Thread
+from typing import Optional, Union
+
+import dve.parser.file_handling as fh
+from dve.core_engine.exceptions import CriticalProcessingError
+from dve.core_engine.loggers import get_logger
+from dve.core_engine.message import UserMessage
+from dve.core_engine.type_hints import URI, DVEStageName, Messages
+
+
+def get_feedback_errors_uri(working_folder: URI, step_name: DVEStageName) -> URI:
+    """Determine the location of the JSON Lines file containing all errors generated in a step"""
+    return fh.joinuri(working_folder, "errors", f"{step_name}_errors.jsonl")
+
+
+def get_processing_errors_uri(working_folder: URI) -> URI:
+    """Determine the location of the JSON Lines file containing all processing
+    errors generated from a DVE run"""
+    return fh.joinuri(working_folder, "processing_errors", "processing_errors.jsonl")
+
+
+def dump_feedback_errors(
+    working_folder: URI,
+    step_name: DVEStageName,
+    messages: Messages,
+    key_fields: Optional[dict[str, list[str]]] = None,
+) -> URI:
+    """Write out captured feedback error messages."""
+    if not working_folder:
+        raise AttributeError("processed files path not passed")
+
+    if not key_fields:
+        key_fields = {}
+
+    error_file = get_feedback_errors_uri(working_folder, step_name)
+    processed = []
+
+    for message in messages:
+        if message.original_entity is not None:
+            primary_keys = key_fields.get(message.original_entity, [])
+        elif message.entity is not None:
+            primary_keys = key_fields.get(message.entity, [])
+        else:
+            primary_keys = []
+
+        error = message.to_dict(
+            key_field=primary_keys,
+            value_separator=" -- ",
+            max_number_of_values=10,
+            record_converter=None,
+        )
+        error["Key"] = conditional_cast(error["Key"], primary_keys, value_separator=" -- ")
+        processed.append(error)
+
+    with fh.open_stream(error_file, "a") as f:
+        f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n")
+    return error_file
+
+
+def dump_processing_errors(
+    working_folder: URI, step_name: str, errors: list[CriticalProcessingError]
+):
+    """Write out critical processing errors"""
+    if not working_folder:
+        raise AttributeError("processed files path not passed")
+    if not step_name:
+        raise AttributeError("step name not passed")
+    if not errors:
+        raise AttributeError("errors list not passed")
+
+    error_file: URI = get_processing_errors_uri(working_folder)
+    processed = []
+
+    for error in errors:
+        processed.append(
+            {
+                "step_name": step_name,
+                "error_location": "processing",
+                "error_level": "integrity",
+                "error_message": error.error_message,
"error_traceback": error.messages, + } + ) + + with fh.open_stream(error_file, "a") as f: + f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n") + + return error_file + + +def load_feedback_messages(feedback_messages_uri: URI) -> Iterable[UserMessage]: + """Load user messages from jsonl file""" + if not fh.get_resource_exists(feedback_messages_uri): + return + with fh.open_stream(feedback_messages_uri) as errs: + yield from (UserMessage(**json.loads(err)) for err in errs.readlines()) + + +def load_all_error_messages(error_directory_uri: URI) -> Iterable[UserMessage]: + "Load user messages from all jsonl files" + return chain.from_iterable( + [ + load_feedback_messages(err_file) + for err_file, _ in fh.iter_prefix(error_directory_uri) + if err_file.endswith(".jsonl") + ] + ) + + +class BackgroundMessageWriter: + """Controls batch writes to error jsonl files""" + + def __init__( + self, + working_directory: URI, + dve_stage: DVEStageName, + key_fields: Optional[dict[str, list[str]]] = None, + logger: Optional[logging.Logger] = None, + ): + self._working_directory = working_directory + self._dve_stage = dve_stage + self._feedback_message_uri = get_feedback_errors_uri( + self._working_directory, self._dve_stage + ) + self._key_fields = key_fields + self.logger = logger or get_logger(type(self).__name__) + self._write_thread: Optional[Thread] = None + self._queue: Queue = Queue() + + @property + def write_queue(self) -> Queue: # type: ignore + """Queue for storing batches of messages to be written""" + return self._queue + + @property + def write_thread(self) -> Thread: # type: ignore + """Thread to write batches of messages to jsonl file""" + if not self._write_thread: + self._write_thread = Thread(target=self._write_process_wrapper) + return self._write_thread + + def _write_process_wrapper(self): + """Wrapper for dump feedback errors to run in background process""" + # writing thread will block if nothing in queue + while True: + if msgs := self.write_queue.get(): + dump_feedback_errors( + self._working_directory, self._dve_stage, msgs, self._key_fields + ) + else: + break + + def __enter__(self) -> "BackgroundMessageWriter": + self.write_thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + self.logger.exception( + "Issue occured during background write process:", + exc_info=(exc_type, exc_value, traceback), + ) + # None value in queue will trigger break in target + self.write_queue.put(None) + self.write_thread.join() + + +def conditional_cast(value, primary_keys: list[str], value_separator: str) -> Union[list[str], str]: + """Determines what to do with a value coming back from the error list""" + if isinstance(value, list): + casts = [ + conditional_cast(val, primary_keys, value_separator) for val in value + ] # type: ignore + return value_separator.join( + [f"{pk}: {id}" if pk else "" for pk, id in zip(primary_keys, casts)] + ) + if isinstance(value, dt.date): + return value.isoformat() + if isinstance(value, dict): + return "" + return str(value) diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py index bed2e17..9d6abaa 100644 --- a/src/dve/core_engine/backends/base/backend.py +++ b/src/dve/core_engine/backends/base/backend.py @@ -17,13 +17,8 @@ from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful from dve.core_engine.loggers import get_logger from dve.core_engine.models import SubmissionInfo -from dve.core_engine.type_hints import ( 
- URI, - EntityLocations, - EntityName, - EntityParquetLocations, - Messages, -) +from dve.core_engine.type_hints import URI, EntityLocations, EntityName, EntityParquetLocations +from dve.parser.file_handling.service import get_parent, joinuri class BaseBackend(Generic[EntityType], ABC): @@ -148,11 +143,12 @@ def convert_entities_to_spark( def apply( self, + working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[Entities, Messages, StageSuccessful]: + ) -> tuple[Entities, URI, StageSuccessful]: """Apply the data contract and the rules, returning the entities and all generated messages. @@ -160,9 +156,11 @@ def apply( reference_data = self.load_reference_data( rule_metadata.reference_data_config, submission_info ) - entities, messages, successful = self.contract.apply(entity_locations, contract_metadata) + entities, dc_feedback_errors_uri, successful, processing_errors_uri = self.contract.apply( + working_dir, entity_locations, contract_metadata + ) if not successful: - return entities, messages, successful + return entities, get_parent(processing_errors_uri), successful for entity_name, entity in entities.items(): entities[entity_name] = self.step_implementations.add_row_id(entity) @@ -170,43 +168,46 @@ def apply( # TODO: Handle entity manager creation errors. entity_manager = EntityManager(entities, reference_data) # TODO: Add stage success to 'apply_rules' - rule_messages = self.step_implementations.apply_rules(entity_manager, rule_metadata) - messages.extend(rule_messages) + # TODO: In case of large errors in business rules, write messages to jsonl file + # TODO: and return uri to errors + _ = self.step_implementations.apply_rules(working_dir, entity_manager, rule_metadata) for entity_name, entity in entity_manager.entities.items(): entity_manager.entities[entity_name] = self.step_implementations.drop_row_id(entity) - return entity_manager.entities, messages, True + return entity_manager.entities, get_parent(dc_feedback_errors_uri), True def process( self, + working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, - cache_prefix: URI, submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[MutableMapping[EntityName, URI], Messages]: + ) -> tuple[MutableMapping[EntityName, URI], URI]: """Apply the data contract and the rules, write the entities out to parquet and returning the entity locations and all generated messages. 
""" - entities, messages, successful = self.apply( - entity_locations, contract_metadata, rule_metadata, submission_info + entities, feedback_errors_uri, successful = self.apply( + working_dir, entity_locations, contract_metadata, rule_metadata, submission_info ) if successful: - parquet_locations = self.write_entities_to_parquet(entities, cache_prefix) + parquet_locations = self.write_entities_to_parquet( + entities, joinuri(working_dir, "outputs") + ) else: parquet_locations = {} - return parquet_locations, messages + return parquet_locations, get_parent(feedback_errors_uri) def process_legacy( self, + working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, rule_metadata: RuleMetadata, - cache_prefix: URI, submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[MutableMapping[EntityName, DataFrame], Messages]: + ) -> tuple[MutableMapping[EntityName, DataFrame], URI]: """Apply the data contract and the rules, create Spark `DataFrame`s from the entities and return the Spark entities and all generated messages. @@ -221,17 +222,19 @@ def process_legacy( category=DeprecationWarning, ) - entities, messages, successful = self.apply( - entity_locations, contract_metadata, rule_metadata, submission_info + entities, errors_uri, successful = self.apply( + working_dir, entity_locations, contract_metadata, rule_metadata, submission_info ) if not successful: - return {}, messages + return {}, errors_uri if self.__entity_type__ == DataFrame: - return entities, messages # type: ignore + return entities, errors_uri # type: ignore return ( - self.convert_entities_to_spark(entities, cache_prefix, _emit_deprecation_warning=False), - messages, + self.convert_entities_to_spark( + entities, joinuri(working_dir, "outputs"), _emit_deprecation_warning=False + ), + errors_uri, ) diff --git a/src/dve/core_engine/backends/base/contract.py b/src/dve/core_engine/backends/base/contract.py index 338bd9f..a431120 100644 --- a/src/dve/core_engine/backends/base/contract.py +++ b/src/dve/core_engine/backends/base/contract.py @@ -9,6 +9,11 @@ from pydantic import BaseModel from typing_extensions import Protocol +from dve.common.error_utils import ( + dump_processing_errors, + get_feedback_errors_uri, + get_processing_errors_uri, +) from dve.core_engine.backends.base.core import get_entity_type from dve.core_engine.backends.base.reader import BaseFileReader from dve.core_engine.backends.exceptions import ReaderLacksEntityTypeSupport, render_error @@ -16,11 +21,13 @@ from dve.core_engine.backends.readers import get_reader from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful from dve.core_engine.backends.utilities import dedup_messages, stringify_model +from dve.core_engine.exceptions import CriticalProcessingError from dve.core_engine.loggers import get_logger from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import ( URI, ArbitraryFunction, + DVEStageName, EntityLocations, EntityName, JSONDict, @@ -96,6 +103,10 @@ class BaseDataContract(Generic[EntityType], ABC): This is set and populated in `__init_subclass__` by identifying methods decorated with the '@reader_override' decorator, and is used in `read_entity_type`. 
+ """ + __stage_name__: DVEStageName = "data_contract" + """ + The name of the data contract DVE stage for use in auditing and logging """ def __init_subclass__(cls, *_, **__) -> None: @@ -360,8 +371,13 @@ def read_raw_entities( @abstractmethod def apply_data_contract( - self, entities: Entities, contract_metadata: DataContractMetadata - ) -> tuple[Entities, Messages, StageSuccessful]: + self, + working_dir: URI, + entities: Entities, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[Entities, URI, StageSuccessful]: """Apply the data contract to the raw entities, returning the validated entities and any messages. @@ -371,35 +387,59 @@ def apply_data_contract( raise NotImplementedError() def apply( - self, entity_locations: EntityLocations, contract_metadata: DataContractMetadata - ) -> tuple[Entities, Messages, StageSuccessful]: + self, + working_dir: URI, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[Entities, URI, StageSuccessful, URI]: """Read the entities from the provided locations according to the data contract, and return the validated entities and any messages. """ + feedback_errors_uri = get_feedback_errors_uri(working_dir, self.__stage_name__) + processing_errors_uri = get_processing_errors_uri(working_dir) entities, messages, successful = self.read_raw_entities(entity_locations, contract_metadata) if not successful: - return {}, messages, successful + dump_processing_errors( + working_dir, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while reading raw entities", + [msg.error_message for msg in messages], + ) + ], + ) + return {}, feedback_errors_uri, successful, processing_errors_uri try: - entities, contract_messages, successful = self.apply_data_contract( - entities, contract_metadata + entities, feedback_errors_uri, successful = self.apply_data_contract( + working_dir, entities, entity_locations, contract_metadata, key_fields ) - messages.extend(contract_messages) except Exception as err: # pylint: disable=broad-except successful = False new_messages = render_error( err, - "data contract", + self.__stage_name__, self.logger, ) - messages.extend(new_messages) + dump_processing_errors( + working_dir, + self.__stage_name__, + [ + CriticalProcessingError( + f"Issue occurred while applying {self.__stage_name__}", + [msg.error_message for msg in new_messages], + ) + ], + ) if contract_metadata.cache_originals: for entity_name in list(entities): entities[f"Original{entity_name}"] = entities[entity_name] - return entities, messages, successful + return entities, feedback_errors_uri, successful, processing_errors_uri def read_parquet(self, path: URI, **kwargs) -> EntityType: """Method to read parquet files from stringified parquet output diff --git a/src/dve/core_engine/backends/base/reference_data.py b/src/dve/core_engine/backends/base/reference_data.py index a9a68fa..5be0ec0 100644 --- a/src/dve/core_engine/backends/base/reference_data.py +++ b/src/dve/core_engine/backends/base/reference_data.py @@ -7,10 +7,30 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated, Literal +import dve.parser.file_handling as fh from dve.core_engine.backends.base.core import get_entity_type -from dve.core_engine.backends.exceptions import MissingRefDataEntity +from dve.core_engine.backends.exceptions import ( + MissingRefDataEntity, + 
RefdataLacksFileExtensionSupport, +) from dve.core_engine.backends.types import EntityType -from dve.core_engine.type_hints import EntityName +from dve.core_engine.type_hints import URI, EntityName +from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation +from dve.parser.file_handling.service import _get_implementation + +_FILE_EXTENSION_NAME: str = "_REFDATA_FILE_EXTENSION" +"""Name of attribute added to methods where they relate + to loading a particular reference file type.""" + + +def mark_refdata_file_extension(file_extension): + """Mark a method for loading a particular file extension""" + + def wrapper(func: Callable): + setattr(func, _FILE_EXTENSION_NAME, file_extension) + return func + + return wrapper class ReferenceTable(BaseModel, frozen=True): @@ -37,7 +57,12 @@ class ReferenceFile(BaseModel, frozen=True): type: Literal["filename"] """The object type.""" filename: str - """The path to the reference data (as Parquet) relative to the contract.""" + """The path to the reference data relative to the contract.""" + + @property + def file_extension(self) -> str: + """The file extension of the reference file""" + return fh.get_file_suffix(self.filename) # type: ignore class ReferenceURI(BaseModel, frozen=True): @@ -48,6 +73,11 @@ class ReferenceURI(BaseModel, frozen=True): uri: str """The absolute URI of the reference data (as Parquet).""" + @property + def file_extension(self) -> str: + """The file extension of the reference uri""" + return fh.get_file_suffix(self.uri) # type: ignore + ReferenceConfig = Union[ReferenceFile, ReferenceTable, ReferenceURI] """The config utilised to load the reference data""" @@ -71,6 +101,12 @@ class BaseRefDataLoader(Generic[EntityType], Mapping[EntityName, EntityType], AB A mapping between refdata config types and functions to call to load these configs into reference data entities """ + + __reader_functions__: ClassVar[dict[str, Callable]] = {} + """ + A mapping between file extensions and functions to load the file uris + into reference data entities + """ prefix: str = "refdata_" def __init_subclass__(cls, *_, **__) -> None: @@ -82,6 +118,9 @@ class variable for the subclass. if cls is not BaseRefDataLoader: cls.__entity_type__ = get_entity_type(cls, "BaseRefDataLoader") + # ensure that dicts are specific to each subclass - redefine rather + # than keep the same reference + cls.__reader_functions__ = {} cls.__step_functions__ = {} for method_name in dir(cls): @@ -92,19 +131,28 @@ class variable for the subclass. if method is None or not callable(method): continue + if ext := getattr(method, _FILE_EXTENSION_NAME, None): + cls.__reader_functions__[ext] = method + continue + type_hints = get_type_hints(method) if set(type_hints.keys()) != {"config", "return"}: continue config_type = type_hints["config"] if not issubclass(config_type, BaseModel): continue + cls.__step_functions__[config_type] = method # type: ignore # pylint: disable=unused-argument def __init__( - self, reference_entity_config: dict[EntityName, ReferenceConfig], **kwargs + self, + reference_entity_config: dict[EntityName, ReferenceConfig], + dataset_config_uri: Optional[URI] = None, + **kwargs, ) -> None: self.reference_entity_config = reference_entity_config + self.dataset_config_uri = dataset_config_uri """ Configuration options for the reference data. This is likely to vary from backend to backend (e.g. 
might be locations and file types for @@ -119,15 +167,30 @@ def load_table(self, config: ReferenceTable) -> EntityType: """Load reference entity from a database table""" raise NotImplementedError() - @abstractmethod def load_file(self, config: ReferenceFile) -> EntityType: "Load reference entity from a relative file path" - raise NotImplementedError() + if not self.dataset_config_uri: + raise AttributeError("dataset_config_uri must be specified if using relative paths") + target_location = fh.build_relative_uri(self.dataset_config_uri, config.filename) + if isinstance(_get_implementation(self.dataset_config_uri), LocalFilesystemImplementation): + target_location = fh.file_uri_to_local_path(target_location).as_posix() + try: + impl = self.__reader_functions__[config.file_extension] + return impl(self, target_location) + except KeyError as exc: + raise RefdataLacksFileExtensionSupport(file_extension=config.file_extension) from exc - @abstractmethod def load_uri(self, config: ReferenceURI) -> EntityType: "Load reference entity from an absolute URI" - raise NotImplementedError() + if isinstance(_get_implementation(config.uri), LocalFilesystemImplementation): + target_location = fh.file_uri_to_local_path(config.uri).as_posix() + else: + target_location = config.uri + try: + impl = self.__reader_functions__[config.file_extension] + return impl(self, target_location) + except KeyError as exc: + raise RefdataLacksFileExtensionSupport(file_extension=config.file_extension) from exc def load_entity(self, entity_name: EntityName, config: ReferenceConfig) -> EntityType: """Load a reference entity given the reference config""" diff --git a/src/dve/core_engine/backends/base/rules.py b/src/dve/core_engine/backends/base/rules.py index ef147b6..b862c27 100644 --- a/src/dve/core_engine/backends/base/rules.py +++ b/src/dve/core_engine/backends/base/rules.py @@ -9,6 +9,12 @@ from typing_extensions import Literal, Protocol, get_type_hints +from dve.common.error_utils import ( + BackgroundMessageWriter, + dump_feedback_errors, + dump_processing_errors, + get_feedback_errors_uri, +) from dve.core_engine.backends.base.core import get_entity_type from dve.core_engine.backends.exceptions import render_error from dve.core_engine.backends.metadata.rules import ( @@ -37,8 +43,9 @@ TableUnion, ) from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful +from dve.core_engine.exceptions import CriticalProcessingError from dve.core_engine.loggers import get_logger -from dve.core_engine.type_hints import URI, EntityName, Messages, TemplateVariables +from dve.core_engine.type_hints import URI, DVEStageName, EntityName, Messages, TemplateVariables T_contra = TypeVar("T_contra", bound=AbstractStep, contravariant=True) T = TypeVar("T", bound=AbstractStep) @@ -81,6 +88,10 @@ class BaseStepImplementations(Generic[EntityType], ABC): # pylint: disable=too- This will be populated from the generic annotation at class creation time. 
+ """ + __stage_name__: DVEStageName = "business_rules" + """ + The name of the business rules DVE stage for use in auditing and logging """ def __init_subclass__(cls, *_, **__) -> None: @@ -188,7 +199,7 @@ def evaluate(self, entities, *, config: AbstractStep) -> tuple[Messages, StageSu if success: success = False msg = f"Critical failure in rule {self._step_metadata_to_location(config)}" - self.logger.error(msg) + self.logger.exception(msg) self.logger.error(str(message)) return messages, success @@ -343,9 +354,14 @@ def notify(self, entities: Entities, *, config: Notification) -> Messages: """ + # pylint: disable=R0912,R0914 def apply_sync_filters( - self, entities: Entities, *filters: DeferredFilter - ) -> tuple[Messages, StageSuccessful]: + self, + working_directory: URI, + entities: Entities, + *filters: DeferredFilter, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[URI, StageSuccessful]: """Apply the synchronised filters, emitting appropriate error messages for any records which do not meet the conditions. @@ -355,109 +371,187 @@ def apply_sync_filters( """ filters_by_entity: dict[EntityName, list[DeferredFilter]] = defaultdict(list) + feedback_errors_uri = get_feedback_errors_uri(working_directory, self.__stage_name__) for rule in filters: filters_by_entity[rule.entity_name].append(rule) - messages: Messages = [] - for entity_name, filter_rules in filters_by_entity.items(): - self.logger.info(f"Applying filters to {entity_name}") - entity = entities[entity_name] - - filter_column_names: list[str] = [] - unmodified_entities = {entity_name: entity} - modified_entities = {entity_name: entity} - - for rule in filter_rules: - self.logger.info(f"Applying filter {rule.reporting.code}") - if rule.reporting.emit == "record_failure": - column_name = f"filter_{uuid4().hex}" - filter_column_names.append(column_name) - temp_messages, success = self.evaluate( - modified_entities, - config=ColumnAddition( - entity_name=entity_name, - column_name=column_name, - expression=rule.expression, - parent=rule.parent, - ), - ) - messages.extend(temp_messages) - if not success: - return messages, False - - temp_messages, success = self.evaluate( - modified_entities, - config=Notification( - entity_name=entity_name, - expression=f"NOT {column_name}", - excluded_columns=filter_column_names, - reporting=rule.reporting, - parent=rule.parent, - ), + with BackgroundMessageWriter( + working_directory=working_directory, + dve_stage=self.__stage_name__, + key_fields=key_fields, + logger=self.logger, + ) as msg_writer: + for entity_name, filter_rules in filters_by_entity.items(): + self.logger.info(f"Applying filters to {entity_name}") + entity = entities[entity_name] + + filter_column_names: list[str] = [] + unmodified_entities = {entity_name: entity} + modified_entities = {entity_name: entity} + + for rule in filter_rules: + self.logger.info(f"Applying filter {rule.reporting.code}") + if rule.reporting.emit == "record_failure": + column_name = f"filter_{uuid4().hex}" + filter_column_names.append(column_name) + temp_messages, success = self.evaluate( + modified_entities, + config=ColumnAddition( + entity_name=entity_name, + column_name=column_name, + expression=rule.expression, + parent=rule.parent, + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while applying filter logic", + messages=[ + msg.error_message + for msg in temp_messages + if msg.error_message + ], + ) + ], + ) + 
return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + temp_messages, success = self.evaluate( + modified_entities, + config=Notification( + entity_name=entity_name, + expression=f"NOT {column_name}", + excluded_columns=filter_column_names, + reporting=rule.reporting, + parent=rule.parent, + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while generating FeedbackMessages", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + self.logger.info( + f"Filter {rule.reporting.code} found {len(temp_messages)} issues" + ) + + else: + temp_messages, success = self.evaluate( + unmodified_entities, + config=Notification( + entity_name=entity_name, + expression=f"NOT ({rule.expression})", + reporting=rule.reporting, + parent=rule.parent, + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while generating FeedbackMessages", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + self.logger.info( + f"Filter {rule.reporting.code} found {len(temp_messages)} issues" + ) + + if filter_column_names: + self.logger.info( + f"Filtering records from entity {entity_name} for error code {rule.reporting.code}" # pylint: disable=line-too-long ) - messages.extend(temp_messages) - if not success: - return messages, False - - else: - temp_messages, success = self.evaluate( - unmodified_entities, - config=Notification( - entity_name=entity_name, - expression=f"NOT ({rule.expression})", - reporting=rule.reporting, - parent=rule.parent, - ), + success_condition = " AND ".join( + [f"({c_name} IS NOT NULL AND {c_name})" for c_name in filter_column_names] ) - messages.extend(temp_messages) - if not success: - return messages, False - - self.logger.info(f"Filter {rule.reporting.code} found {len(temp_messages)} issues") - - if filter_column_names: - self.logger.info( - f"Filtering records from entity {entity_name} for error code {rule.reporting.code}" # pylint: disable=line-too-long - ) - success_condition = " AND ".join( - [f"({c_name} IS NOT NULL AND {c_name})" for c_name in filter_column_names] - ) - temp_messages, success = self.evaluate( - modified_entities, - config=ImmediateFilter( - entity_name=entity_name, - expression=success_condition, - parent=ParentMetadata( - rule="FilterStageRecordLevelFilterApplication", index=0, stage="Sync" - ), - ), - ) - messages.extend(temp_messages) - if not success: - return messages, False - - for index, filter_column_name in enumerate(filter_column_names): temp_messages, success = self.evaluate( modified_entities, - config=ColumnRemoval( + config=ImmediateFilter( entity_name=entity_name, - column_name=filter_column_name, + expression=success_condition, parent=ParentMetadata( - rule="FilterStageRecordLevelFilterColumnRemoval", - index=index, + rule="FilterStageRecordLevelFilterApplication", + index=0, stage="Sync", ), ), ) - messages.extend(temp_messages) if not success: - return messages, False - - entities.update(modified_entities) - - return messages, True - - def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messages: + processing_errors_uri = 
dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while filtering error records", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + for index, filter_column_name in enumerate(filter_column_names): + temp_messages, success = self.evaluate( + modified_entities, + config=ColumnRemoval( + entity_name=entity_name, + column_name=filter_column_name, + parent=ParentMetadata( + rule="FilterStageRecordLevelFilterColumnRemoval", + index=index, + stage="Sync", + ), + ), + ) + if not success: + processing_errors_uri = dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while generating FeedbackMessages", + [msg.error_message for msg in temp_messages], + ) + ], + ) + return processing_errors_uri, False + if temp_messages: + msg_writer.write_queue.put(temp_messages) + + entities.update(modified_entities) + + return feedback_errors_uri, True + + def apply_rules( + self, + working_directory: URI, + entities: Entities, + rule_metadata: RuleMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[URI, bool]: """Create rule definitions from the metadata for a given dataset and evaluate the impact on the provided entities, returning a deque of messages and altering the entities in-place. @@ -465,6 +559,7 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag """ self.logger.info("Applying business rules") rules_and_locals: Iterable[tuple[Rule, TemplateVariables]] + errors_uri = get_feedback_errors_uri(working_directory, self.__stage_name__) if rule_metadata.templating_strategy == "upfront": rules_and_locals = [] for rule, local_variables in rule_metadata: @@ -479,8 +574,7 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag else: rules_and_locals = rule_metadata - messages: Messages = [] - + pre_sync_messages: Messages = [] self.logger.info("Applying pre-sync steps") for rule, local_variables in rules_and_locals: for step in rule.pre_sync_steps: @@ -490,9 +584,29 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag ) stage_messages, success = self.evaluate(entities, config=step) - messages.extend(stage_messages) + # if failure, write out processing issues and all prior messages (so nothing lost) if not success: - return messages + processing_errors_uri = dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while applying pre filter steps", + [msg.error_message for msg in stage_messages], + ) + ], + ) + if pre_sync_messages: + dump_feedback_errors( + working_directory, self.__stage_name__, pre_sync_messages + ) + + return processing_errors_uri, False + # if not a failure, ensure we keep track of any informational messages + pre_sync_messages.extend(stage_messages) + # if all successful, ensure we write out all informational messages + if pre_sync_messages: + dump_feedback_errors(working_directory, self.__stage_name__, pre_sync_messages) sync_steps = [] for rule, local_variables in rules_and_locals: @@ -503,10 +617,15 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag ) sync_steps.append(step) - stage_messages, success = self.apply_sync_filters(entities, *sync_steps) - messages.extend(stage_messages) + # error writing handled in apply_sync_filters + errors_uri, 
success = self.apply_sync_filters( + working_directory, entities, *sync_steps, key_fields=key_fields + ) if not success: - return messages + return errors_uri, False + + post_sync_messages: Messages = [] + self.logger.info("Applying post-sync steps") self.logger.info("Applying post-sync steps") @@ -518,10 +637,29 @@ def apply_rules(self, entities: Entities, rule_metadata: RuleMetadata) -> Messag ) stage_messages, success = self.evaluate(entities, config=step) - messages.extend(stage_messages) if not success: - return messages - return messages + processing_errors_uri = dump_processing_errors( + working_directory, + self.__stage_name__, + [ + CriticalProcessingError( + "Issue occurred while applying post filter steps", + [msg.error_message for msg in stage_messages], + ) + ], + ) + if post_sync_messages: + dump_feedback_errors( + working_directory, self.__stage_name__, post_sync_messages + ) + + return processing_errors_uri, False + # if not a failure, ensure we keep track of any informational messages + post_sync_messages.extend(stage_messages) + # if all successful, ensure we write out all informational messages + if post_sync_messages: + dump_feedback_errors(working_directory, self.__stage_name__, post_sync_messages) + return errors_uri, True def read_parquet(self, path: URI, **kwargs) -> EntityType: """Method to read parquet files""" diff --git a/src/dve/core_engine/backends/base/utilities.py b/src/dve/core_engine/backends/base/utilities.py index 30efc74..f55bc88 100644 --- a/src/dve/core_engine/backends/base/utilities.py +++ b/src/dve/core_engine/backends/base/utilities.py @@ -5,8 +5,12 @@ from collections.abc import Sequence from typing import Optional +import pyarrow # type: ignore +import pyarrow.parquet as pq # type: ignore + from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import ExpressionArray, MultiExpression +from dve.parser.type_hints import URI BRACKETS = {"(": ")", "{": "}", "[": "]", "<": ">"} """A mapping of opening brackets to their closing counterpart.""" @@ -131,3 +135,12 @@ def _get_non_heterogenous_type(types: Sequence[type]) -> type: + f"union types (got {type_list!r}) but nullable types are okay" ) return type_list[0] + + +def check_if_parquet_file(file_location: URI) -> bool: + """Check if a file path is valid parquet""" + try: + pq.ParquetFile(file_location) + return True + except (pyarrow.ArrowInvalid, pyarrow.ArrowIOError): + return False diff --git a/src/dve/core_engine/backends/exceptions.py b/src/dve/core_engine/backends/exceptions.py index 279f4ce..8dd50ef 100644 --- a/src/dve/core_engine/backends/exceptions.py +++ b/src/dve/core_engine/backends/exceptions.py @@ -108,7 +108,7 @@ def get_joiner(self): return "required by" -class MissingRefDataEntity(MissingEntity): # pylint: disable=too-many-ancestors +class MissingRefDataEntity(MissingEntity, BackendErrorMixin): # pylint: disable=too-many-ancestors """An error to be emitted when a required refdata entity is missing.""" def get_message_preamble(self) -> str: @@ -166,6 +166,23 @@ def get_message_preamble(self) -> EntityName: return f"Reader does not support reading directly to entity type {self.entity_type!r}" +class RefdataLacksFileExtensionSupport(BackendError): + """An error raised when trying to load a refdata file where the loader + lacks support for the given file type + + """ + + def __init__(self, *args: object, file_extension: str) -> None: + super().__init__(*args) + self.file_extension = file_extension + """The file extension that is not supported directly by the 
+ refdata loader""" + + def get_message_preamble(self) -> EntityName: + """Message for logging purposes""" + return f"Refdata loader does not support reading refdata from {self.file_extension} files" + + class EmptyFileError(ReaderErrorMixin, ValueError): """The read file was empty.""" diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 51017a5..b71be85 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -3,21 +3,32 @@ # pylint: disable=R0903 import logging from collections.abc import Iterator +from functools import partial +from multiprocessing import Pool, cpu_count from typing import Any, Optional from uuid import uuid4 import pandas as pd import polars as pl +import pyarrow.parquet as pq # type: ignore from duckdb import DuckDBPyConnection, DuckDBPyRelation from duckdb.typing import DuckDBPyType from polars.datatypes.classes import DataTypeClass as PolarsType from pydantic import BaseModel from pydantic.fields import ModelField +import dve.parser.file_handling as fh +from dve.common.error_utils import ( + BackgroundMessageWriter, + dump_processing_errors, + get_feedback_errors_uri, +) from dve.core_engine.backends.base.contract import BaseDataContract -from dve.core_engine.backends.base.utilities import generate_error_casting_entity_message +from dve.core_engine.backends.base.utilities import ( + check_if_parquet_file, + generate_error_casting_entity_message, +) from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( - coerce_inferred_numpy_array_to_list, duckdb_read_parquet, duckdb_write_parquet, get_duckdb_type_from_annotation, @@ -28,8 +39,8 @@ from dve.core_engine.backends.types import StageSuccessful from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model from dve.core_engine.message import FeedbackMessage -from dve.core_engine.type_hints import URI, Messages -from dve.core_engine.validation import RowValidator +from dve.core_engine.type_hints import URI, EntityLocations +from dve.core_engine.validation import RowValidator, apply_row_validator_helper class PandasApplyHelper: @@ -49,7 +60,8 @@ def __call__(self, row: pd.Series): class DuckDBDataContract(BaseDataContract[DuckDBPyRelation]): """An implementation of a data contract in DuckDB. - This utilises the conversion from relation to pandas dataframe to apply the data contract. + This utilises pyarrow to distibute parquet data across python processes and + a background process to write error messages. 
""" @@ -71,8 +83,8 @@ def connection(self) -> DuckDBPyConnection: """The duckdb connection""" return self._connection - def _cache_records(self, relation: DuckDBPyRelation, cache_prefix: URI) -> URI: - chunk_uri = "/".join((cache_prefix.rstrip("/"), str(uuid4()))) + ".parquet" + def _cache_records(self, relation: DuckDBPyRelation, working_dir: URI) -> URI: + chunk_uri = "/".join((working_dir.rstrip("/"), str(uuid4()))) + ".parquet" self.write_parquet(entity=relation, target_location=chunk_uri) return chunk_uri @@ -98,77 +110,106 @@ def generate_ddb_cast_statement( return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"' return f'cast(NULL AS {dtype}) AS "{column_name}"' + # pylint: disable=R0914 def apply_data_contract( - self, entities: DuckDBEntities, contract_metadata: DataContractMetadata - ) -> tuple[DuckDBEntities, Messages, StageSuccessful]: + self, + working_dir: URI, + entities: DuckDBEntities, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[DuckDBEntities, URI, StageSuccessful]: """Apply the data contract to the duckdb relations""" - all_messages: Messages = [] + self.logger.info("Applying data contracts") + feedback_errors_uri: URI = get_feedback_errors_uri(working_dir, "data_contract") + + # check if entities are valid parquet - if not, convert + for entity, entity_loc in entity_locations.items(): + if not check_if_parquet_file(entity_loc): + parquet_uri = self.write_parquet( + entities[entity], fh.joinuri(fh.get_parent(entity_loc), f"{entity}.parquet") + ) + entity_locations[entity] = parquet_uri successful = True - for entity_name, relation in entities.items(): - # get dtypes for all fields -> python data types or use with relation - entity_fields: dict[str, ModelField] = contract_metadata.schemas[entity_name].__fields__ - ddb_schema: dict[str, DuckDBPyType] = { - fld.name: get_duckdb_type_from_annotation(fld.annotation) - for fld in entity_fields.values() - } - polars_schema: dict[str, PolarsType] = { - fld.name: get_polars_type_from_annotation(fld.annotation) - for fld in entity_fields.values() - } - if relation_is_empty(relation): - self.logger.warning(f"+ Empty relation for {entity_name}") - empty_df = pl.DataFrame([], schema=polars_schema) # type: ignore # pylint: disable=W0612 - relation = self._connection.sql("select * from empty_df") - continue - - self.logger.info(f"+ Applying contract to: {entity_name}") - - row_validator = contract_metadata.validators[entity_name] - application_helper = PandasApplyHelper(row_validator) - self.logger.info("+ Applying data contract") - coerce_inferred_numpy_array_to_list(relation.df()).apply( - application_helper, axis=1 - ) # pandas uses eager evaluation so potential memory issue here? 
- self.logger.info( - f"Data contract found {len(application_helper.errors)} issues in {entity_name}" - ) - all_messages.extend(application_helper.errors) - - casting_statements = [ - ( - self.generate_ddb_cast_statement(column, dtype) - if column in relation.columns - else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + + with BackgroundMessageWriter( + working_dir, "data_contract", key_fields=key_fields + ) as msg_writer: + for entity_name, relation in entities.items(): + # get dtypes for all fields -> python data types or use with relation + entity_fields: dict[str, ModelField] = contract_metadata.schemas[ + entity_name + ].__fields__ + ddb_schema: dict[str, DuckDBPyType] = { + fld.name: get_duckdb_type_from_annotation(fld.annotation) + for fld in entity_fields.values() + } + polars_schema: dict[str, PolarsType] = { + fld.name: get_polars_type_from_annotation(fld.annotation) + for fld in entity_fields.values() + } + if relation_is_empty(relation): + self.logger.warning(f"+ Empty relation for {entity_name}") + empty_df = pl.DataFrame([], schema=polars_schema) # type: ignore # pylint: disable=W0612 + relation = self._connection.sql("select * from empty_df") + continue + + self.logger.info(f"+ Applying contract to: {entity_name}") + + row_validator_helper = partial( + apply_row_validator_helper, + row_validator=contract_metadata.validators[entity_name], ) - for column, dtype in ddb_schema.items() - ] - try: - relation = relation.project(", ".join(casting_statements)) - except Exception as err: # pylint: disable=broad-except - successful = False - self.logger.error(f"Error in casting relation: {err}") - all_messages.append(generate_error_casting_entity_message(entity_name)) - continue - - if self.debug: - # count will force evaluation - only done in debug - pre_convert_row_count = relation.count("*").fetchone()[0] # type: ignore - self.logger.info(f"+ Converting to parquet: ({pre_convert_row_count} rows)") - else: - pre_convert_row_count = 0 - self.logger.info("+ Converting to parquet") - - entities[entity_name] = relation - if self.debug: - post_convert_row_count = entities[entity_name].count("*").fetchone()[0] # type: ignore # pylint:disable=line-too-long - self.logger.info(f"+ Converted to parquet: ({post_convert_row_count} rows)") - if post_convert_row_count != pre_convert_row_count: - raise ValueError( - f"Row count mismatch for {entity_name}" - f" ({pre_convert_row_count} vs {post_convert_row_count})" - ) - else: - self.logger.info("+ Converted to parquet") - return entities, all_messages, successful + batches = pq.ParquetFile(entity_locations[entity_name]).iter_batches(10000) + msg_count = 0 + with Pool(cpu_count() - 1) as pool: + for msgs in pool.imap_unordered(row_validator_helper, batches): + if msgs: + msg_writer.write_queue.put(msgs) + msg_count += len(msgs) + + self.logger.info(f"Data contract found {msg_count} issues in {entity_name}") + + casting_statements = [ + ( + self.generate_ddb_cast_statement(column, dtype) + if column in relation.columns + else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + ) + for column, dtype in ddb_schema.items() + ] + try: + relation = relation.project(", ".join(casting_statements)) + except Exception as err: # pylint: disable=broad-except + successful = False + self.logger.error(f"Error in casting relation: {err}") + dump_processing_errors( + working_dir, + "data_contract", + [generate_error_casting_entity_message(entity_name)], + ) + continue + + if self.debug: + # count will force evaluation - only done in 
debug + pre_convert_row_count = relation.count("*").fetchone()[0] # type: ignore + self.logger.info(f"+ Converting to parquet: ({pre_convert_row_count} rows)") + else: + pre_convert_row_count = 0 + self.logger.info("+ Converting to parquet") + + entities[entity_name] = relation + if self.debug: + post_convert_row_count = entities[entity_name].count("*").fetchone()[0] # type: ignore # pylint:disable=line-too-long + self.logger.info(f"+ Converted to parquet: ({post_convert_row_count} rows)") + if post_convert_row_count != pre_convert_row_count: + raise ValueError( + f"Row count mismatch for {entity_name}" + f" ({pre_convert_row_count} vs {post_convert_row_count})" + ) + else: + self.logger.info("+ Converted to parquet") + + return entities, feedback_errors_uri, successful diff --git a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py index 5df3f6a..c10aed7 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py +++ b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py @@ -3,21 +3,15 @@ from typing import Optional from duckdb import DuckDBPyConnection, DuckDBPyRelation +from pyarrow import ipc # type: ignore -import dve.parser.file_handling as fh from dve.core_engine.backends.base.reference_data import ( BaseRefDataLoader, ReferenceConfigUnion, - ReferenceFile, ReferenceTable, - ReferenceURI, + mark_refdata_file_extension, ) from dve.core_engine.type_hints import EntityName -from dve.parser.file_handling.implementations.file import ( - LocalFilesystemImplementation, - file_uri_to_local_path, -) -from dve.parser.file_handling.service import _get_implementation from dve.parser.type_hints import URI @@ -35,7 +29,7 @@ def __init__( reference_entity_config: dict[EntityName, ReferenceConfigUnion], **kwargs, ) -> None: - super().__init__(reference_entity_config, **kwargs) + super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs) if not self.connection: raise AttributeError("DuckDBConnection must be specified") @@ -44,19 +38,12 @@ def load_table(self, config: ReferenceTable) -> DuckDBPyRelation: """Load reference entity from a database table""" return self.connection.sql(f"select * from {config.fq_table_name}") - def load_file(self, config: ReferenceFile) -> DuckDBPyRelation: - "Load reference entity from a relative file path" - if not self.dataset_config_uri: - raise AttributeError("dataset_config_uri must be specified if using relative paths") - target_location = fh.build_relative_uri(self.dataset_config_uri, config.filename) - if isinstance(_get_implementation(self.dataset_config_uri), LocalFilesystemImplementation): - target_location = file_uri_to_local_path(target_location).as_posix() - return self.connection.read_parquet(target_location) - - def load_uri(self, config: ReferenceURI) -> DuckDBPyRelation: - "Load reference entity from an absolute URI" - if isinstance(_get_implementation(config.uri), LocalFilesystemImplementation): - target_location = file_uri_to_local_path(config.uri).as_posix() - else: - target_location = config.uri - return self.connection.read_parquet(target_location) + @mark_refdata_file_extension("parquet") + def load_parquet_file(self, uri: str) -> DuckDBPyRelation: + """Load a parquet file into a duckdb relation""" + return self.connection.read_parquet(uri) + + @mark_refdata_file_extension("arrow") + def load_arrow_file(self, uri: str) -> DuckDBPyRelation: + """Load an arrow ipc file into a duckdb relation""" + return 
self.connection.from_arrow(ipc.open_file(uri).read_all()) # type:ignore diff --git a/src/dve/core_engine/backends/implementations/spark/contract.py b/src/dve/core_engine/backends/implementations/spark/contract.py index afbd85e..d8078bd 100644 --- a/src/dve/core_engine/backends/implementations/spark/contract.py +++ b/src/dve/core_engine/backends/implementations/spark/contract.py @@ -2,6 +2,7 @@ import logging from collections.abc import Iterator +from itertools import islice from typing import Any, Optional from uuid import uuid4 @@ -11,6 +12,11 @@ from pyspark.sql.functions import col, lit from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType +from dve.common.error_utils import ( + BackgroundMessageWriter, + dump_processing_errors, + get_feedback_errors_uri, +) from dve.core_engine.backends.base.contract import BaseDataContract, reader_override from dve.core_engine.backends.base.utilities import generate_error_casting_entity_message from dve.core_engine.backends.exceptions import ( @@ -29,7 +35,7 @@ from dve.core_engine.backends.readers import CSVFileReader from dve.core_engine.backends.types import StageSuccessful from dve.core_engine.constants import ROWID_COLUMN_NAME -from dve.core_engine.type_hints import URI, EntityName, Messages +from dve.core_engine.type_hints import URI, EntityLocations, EntityName COMPLEX_TYPES: set[type[DataType]] = {StructType, ArrayType, MapType} """Spark types indicating complex types.""" @@ -61,12 +67,12 @@ def __init__( super().__init__(logger, **kwargs) - def _cache_records(self, dataframe: DataFrame, cache_prefix: URI) -> URI: + def _cache_records(self, dataframe: DataFrame, working_dir: URI) -> URI: """Write a chunk of records out to the cache dir, returning the path to the parquet file. 
""" - chunk_uri = "/".join((cache_prefix.rstrip("/"), str(uuid4()))) + ".parquet" + chunk_uri = "/".join((working_dir.rstrip("/"), str(uuid4()))) + ".parquet" dataframe.write.parquet(chunk_uri) return chunk_uri @@ -79,10 +85,17 @@ def create_entity_from_py_iterator( ) def apply_data_contract( - self, entities: SparkEntities, contract_metadata: DataContractMetadata - ) -> tuple[SparkEntities, Messages, StageSuccessful]: + self, + working_dir: URI, + entities: SparkEntities, + entity_locations: EntityLocations, + contract_metadata: DataContractMetadata, + key_fields: Optional[dict[str, list[str]]] = None, + ) -> tuple[SparkEntities, URI, StageSuccessful]: self.logger.info("Applying data contracts") - all_messages: Messages = [] + + entity_locations = {} if not entity_locations else entity_locations + feedback_errors_uri = get_feedback_errors_uri(working_dir, "data_contract") successful = True for entity_name, record_df in entities.items(): @@ -112,11 +125,18 @@ def apply_data_contract( record_df.rdd.map(lambda row: row.asDict(True)).map(row_validator) # .persist(storageLevel=StorageLevel.MEMORY_AND_DISK) ) - messages = validated.flatMap(lambda row: row[1]).filter(bool) - messages.cache() - self.logger.info(f"Data contract found {messages.count()} issues in {entity_name}") - all_messages.extend(messages.collect()) - messages.unpersist() + with BackgroundMessageWriter( + working_dir, "data_contract", key_fields, self.logger + ) as msg_writer: + messages = validated.flatMap(lambda row: row[1]).filter(bool).toLocalIterator() + msg_count = 0 + while True: + batch = list(islice(messages, 10000)) + if not batch: + break + msg_writer.write_queue.put(batch) + msg_count += len(batch) + self.logger.info(f"Data contract found {msg_count} issues in {entity_name}") try: record_df = record_df.select( @@ -129,7 +149,11 @@ def apply_data_contract( except Exception as err: # pylint: disable=broad-except successful = False self.logger.error(f"Error in converting to dataframe: {err}") - all_messages.append(generate_error_casting_entity_message(entity_name)) + dump_processing_errors( + working_dir, + "data_contract", + [generate_error_casting_entity_message(entity_name)], + ) continue if self.debug: @@ -155,7 +179,7 @@ def apply_data_contract( else: self.logger.info("+ Converted to Dataframe") - return entities, all_messages, successful + return entities, feedback_errors_uri, successful @reader_override(CSVFileReader) def read_csv_file( diff --git a/src/dve/core_engine/backends/implementations/spark/reference_data.py b/src/dve/core_engine/backends/implementations/spark/reference_data.py index de323d7..90ba4f6 100644 --- a/src/dve/core_engine/backends/implementations/spark/reference_data.py +++ b/src/dve/core_engine/backends/implementations/spark/reference_data.py @@ -5,13 +5,11 @@ from pyspark.sql import DataFrame, SparkSession -import dve.parser.file_handling as fh from dve.core_engine.backends.base.reference_data import ( BaseRefDataLoader, ReferenceConfig, - ReferenceFile, ReferenceTable, - ReferenceURI, + mark_refdata_file_extension, ) from dve.core_engine.type_hints import EntityName from dve.parser.type_hints import URI @@ -31,18 +29,14 @@ def __init__( reference_entity_config: dict[EntityName, ReferenceConfig], **kwargs, ) -> None: - super().__init__(reference_entity_config, **kwargs) + super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs) if not self.spark: raise AttributeError("Spark session must be provided") def load_table(self, config: ReferenceTable) -> DataFrame: return 
self.spark.table(f"{config.fq_table_name}") - def load_file(self, config: ReferenceFile) -> DataFrame: - if not self.dataset_config_uri: - raise AttributeError("dataset_config_uri must be specified if using relative paths") - target_location = fh.build_relative_uri(self.dataset_config_uri, config.filename) - return self.spark.read.parquet(target_location) - - def load_uri(self, config: ReferenceURI) -> DataFrame: - return self.spark.read.parquet(config.uri) + @mark_refdata_file_extension("parquet") + def load_parquet_file(self, uri: str) -> DataFrame: + """Load a parquet file into a spark dataframe""" + return self.spark.read.parquet(uri) diff --git a/src/dve/core_engine/backends/utilities.py b/src/dve/core_engine/backends/utilities.py index bfa6f90..9261806 100644 --- a/src/dve/core_engine/backends/utilities.py +++ b/src/dve/core_engine/backends/utilities.py @@ -4,8 +4,7 @@ from dataclasses import is_dataclass from datetime import date, datetime, time from decimal import Decimal -from typing import GenericAlias # type: ignore -from typing import Any, ClassVar, Union +from typing import Any, ClassVar, GenericAlias, Union # type: ignore import polars as pl # type: ignore from polars.datatypes.classes import DataTypeClass as PolarsType diff --git a/src/dve/core_engine/engine.py b/src/dve/core_engine/engine.py index 87ab0b6..28a2ac5 100644 --- a/src/dve/core_engine/engine.py +++ b/src/dve/core_engine/engine.py @@ -1,12 +1,8 @@ """The core engine for the data validation engine.""" -import csv import json import logging -import shutil -from contextlib import ExitStack from pathlib import Path -from tempfile import NamedTemporaryFile from types import TracebackType from typing import Any, Optional, Union @@ -21,16 +17,9 @@ from dve.core_engine.configuration.v1 import V1EngineConfig from dve.core_engine.constants import ROWID_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger -from dve.core_engine.message import FeedbackMessage from dve.core_engine.models import EngineRunValidation, SubmissionInfo -from dve.core_engine.type_hints import EntityName, JSONstring, Messages -from dve.parser.file_handling import ( - TemporaryPrefix, - get_resource_exists, - joinuri, - open_stream, - resolve_location, -) +from dve.core_engine.type_hints import EntityName, JSONstring +from dve.parser.file_handling import TemporaryPrefix, get_resource_exists, joinuri, resolve_location from dve.parser.type_hints import URI, Location @@ -47,7 +36,7 @@ class Config: # pylint: disable=too-few-public-methods """The backend configuration for the given run.""" dataset_config_uri: URI """The dischema location for the current run""" - output_prefix_uri: URI = Field(default_factory=lambda: Path("outputs").resolve().as_uri()) + output_prefix_uri: URI = Field(default_factory=lambda: Path("outputs").resolve().as_posix()) """The prefix for the parquet outputs.""" main_log: logging.Logger = Field(default_factory=lambda: get_logger("CoreEngine")) """The `logging.Logger instance for the data ingest process.""" @@ -129,14 +118,19 @@ def build( main_log.info(f"Debug mode: {debug}") if isinstance(dataset_config_path, Path): - dataset_config_uri = dataset_config_path.resolve().as_uri() + dataset_config_uri = dataset_config_path.resolve().as_posix() else: dataset_config_uri = dataset_config_path + if isinstance(output_prefix, Path): + output_prefix_uri = output_prefix.resolve().as_posix() + else: + output_prefix_uri = output_prefix + backend_config = V1EngineConfig.load(dataset_config_uri) self = cls( 
dataset_config_uri=dataset_config_uri, - output_prefix_uri=output_prefix, + output_prefix_uri=output_prefix_uri, main_log=main_log, cache_prefix_uri=cache_prefix, backend_config=backend_config, @@ -223,64 +217,14 @@ def _write_entity_outputs(self, entities: SparkEntities) -> SparkEntities: return output_entities - def _write_exception_report(self, messages: Messages) -> None: - """Write an exception report to the ouptut prefix. This is currently - a pipe-delimited CSV file containing all the messages emitted by the - pipeline. - - """ - # need to write using temp files and put to s3 with self.fs_impl? - - self.main_log.info(f"Creating exception report in the output dir: {self.output_prefix_uri}") - self.main_log.info("Splitting errors by category") - - contract_metadata = self.backend_config.get_contract_metadata() - with ExitStack() as file_contexts: - critical_file = file_contexts.enter_context(NamedTemporaryFile("r+", encoding="utf-8")) - critical_writer = csv.writer(critical_file, delimiter="|", lineterminator="\n") - - standard_file = file_contexts.enter_context(NamedTemporaryFile("r+", encoding="utf-8")) - standard_writer = csv.writer(standard_file, delimiter="|", lineterminator="\n") - - warning_file = file_contexts.enter_context(NamedTemporaryFile("r+", encoding="utf-8")) - warning_writer = csv.writer(warning_file, delimiter="|", lineterminator="\n") - - for message in messages: - if message.entity: - key_field = contract_metadata.reporting_fields.get(message.entity, None) - else: - key_field = None - message_row = message.to_row(key_field) - - if message.is_critical: - critical_writer.writerow(message_row) - elif message.is_informational: - warning_writer.writerow(message_row) - else: - standard_writer.writerow(message_row) - - output_uri = joinuri(self.output_prefix_uri, "pipeline.errors") - with open_stream(output_uri, "w", encoding="utf-8") as combined_error_file: - combined_error_file.write("|".join(FeedbackMessage.HEADER) + "\n") - for temp_error_file in [critical_file, standard_file, warning_file]: - temp_error_file.seek(0) - shutil.copyfileobj(temp_error_file, combined_error_file) - - def _write_outputs( - self, entities: SparkEntities, messages: Messages, verbose: bool = False - ) -> tuple[SparkEntities, Messages]: + def _write_outputs(self, entities: SparkEntities) -> SparkEntities: """Write the outputs from the pipeline, returning the written entities and messages. """ entities = self._write_entity_outputs(entities) - if verbose: - self._write_exception_report(messages) - else: - self.main_log.info("Skipping exception report") - - return entities, messages + return entities def _show_available_entities(self, entities: SparkEntities, *, verbose: bool = False) -> None: """Print current entities.""" @@ -301,11 +245,9 @@ def _show_available_entities(self, entities: SparkEntities, *, verbose: bool = F def run_pipeline( self, entity_locations: dict[EntityName, URI], - *, - verbose: bool = False, # pylint: disable=unused-argument submission_info: Optional[SubmissionInfo] = None, - ) -> tuple[SparkEntities, Messages]: + ) -> tuple[SparkEntities, URI]: """Run the pipeline, reading in the entities and applying validation and transformation rules, and then write the outputs. @@ -313,11 +255,11 @@ def run_pipeline( references should be valid after the pipeline context exits. 
""" - entities, messages = self.backend.process_legacy( + entities, errors_uri = self.backend.process_legacy( + self.output_prefix_uri, entity_locations, self.backend_config.get_contract_metadata(), self.backend_config.get_rule_metadata(), - self.cache_prefix, submission_info, ) - return self._write_outputs(entities, messages, verbose=verbose) + return self._write_outputs(entities), errors_uri diff --git a/src/dve/core_engine/message.py b/src/dve/core_engine/message.py index dd580c6..f2a4e52 100644 --- a/src/dve/core_engine/message.py +++ b/src/dve/core_engine/message.py @@ -1,5 +1,6 @@ """Functionality to represent messages.""" +# pylint: disable=C0103 import copy import datetime as dt import json @@ -17,10 +18,16 @@ from dve.core_engine.type_hints import ( EntityName, ErrorCategory, + ErrorCode, + ErrorLocation, + ErrorMessage, + ErrorType, FailureType, Messages, MessageTuple, Record, + ReportingField, + Status, ) from dve.parser.type_hints import FieldName @@ -82,6 +89,45 @@ class Config: # pylint: disable=too-few-public-methods arbitrary_types_allowed = True +# pylint: disable=R0902 +@dataclass +class UserMessage: + """The structure of the message that is used to populate the error report.""" + + Entity: Optional[str] + """The entity that the message pertains to (if applicable).""" + Key: Optional[str] + "The key field(s) in string format to allow users to identify the record" + FailureType: FailureType + "The type of failure" + Status: Status + "Indicating if an error or warning" + ErrorType: ErrorType + "The type of error" + ErrorLocation: ErrorLocation + "The source of the error" + ErrorMessage: ErrorMessage + "The error message to summarise the error" + ErrorCode: ErrorCode + "The error code of the error" + ReportingField: ReportingField + "The field(s) that the error relates to" + Value: Any + "The offending values" + Category: ErrorCategory + "The category of error" + + @property + def is_informational(self) -> bool: + "Indicates whether the message is a warning" + return self.Status == "informational" + + @property + def is_critical(self) -> bool: + "Indicates if the message relates to a processing issue" + return self.FailureType == "integrity" + + @dataclass(config=Config, eq=True) class FeedbackMessage: # pylint: disable=too-many-instance-attributes """Information which affects processing and needs to be feeded back.""" diff --git a/src/dve/core_engine/type_hints.py b/src/dve/core_engine/type_hints.py index 0be3763..afb6d9d 100644 --- a/src/dve/core_engine/type_hints.py +++ b/src/dve/core_engine/type_hints.py @@ -244,3 +244,12 @@ BinaryComparator = Callable[[Any, Any], bool] """Type hint for operator functions""" + +DVEStageName = Literal[ + "audit_received", + "file_transformation", + "data_contract", + "business_rules", + "error_report", + "pipeline", +] diff --git a/src/dve/core_engine/validation.py b/src/dve/core_engine/validation.py index 2be101e..f62309b 100644 --- a/src/dve/core_engine/validation.py +++ b/src/dve/core_engine/validation.py @@ -1,8 +1,11 @@ """XML schema/contract configuration.""" +# pylint: disable=E0611 import warnings +from itertools import chain from typing import Optional +from pyarrow.lib import RecordBatch # type: ignore from pydantic import ValidationError from pydantic.main import ModelMetaclass @@ -145,3 +148,13 @@ def handle_warnings(self, record, caught_warnings) -> list[FeedbackMessage]: ) ) return messages + + +def apply_row_validator_helper( + arrow_batch: RecordBatch, *, row_validator: RowValidator +) -> Messages: + """Helper to 
distribute data efficiently over python processes and then convert + to dictionaries for applying pydantic model""" + return list( + chain.from_iterable(msgs for _, msgs in map(row_validator, arrow_batch.to_pylist())) + ) diff --git a/src/dve/parser/file_handling/implementations/file.py b/src/dve/parser/file_handling/implementations/file.py index eeed3de..76d8b58 100644 --- a/src/dve/parser/file_handling/implementations/file.py +++ b/src/dve/parser/file_handling/implementations/file.py @@ -9,7 +9,7 @@ from typing_extensions import Literal -from dve.parser.exceptions import FileAccessError, UnsupportedSchemeError +from dve.parser.exceptions import FileAccessError from dve.parser.file_handling.helpers import parse_uri from dve.parser.file_handling.implementations.base import BaseFilesystemImplementation from dve.parser.type_hints import URI, NodeType, PathStr, Scheme @@ -20,11 +20,7 @@ def file_uri_to_local_path(uri: URI) -> Path: """Resolve a `file://` URI to a local filesystem path.""" - scheme, hostname, path = parse_uri(uri) - if scheme not in FILE_URI_SCHEMES: # pragma: no cover - raise UnsupportedSchemeError( - f"Local filesystem must use an allowed file URI scheme, got {scheme!r}" - ) + _, hostname, path = parse_uri(uri) path = unquote(path) # Unfortunately Windows is awkward. @@ -54,7 +50,7 @@ def _path_to_uri(self, path: Path) -> URI: easier to create implementations for 'file-like' protocols. """ - return path.as_uri() + return path.as_posix() @staticmethod def _handle_error( diff --git a/src/dve/parser/file_handling/service.py b/src/dve/parser/file_handling/service.py index 0422b4c..9ee9d9f 100644 --- a/src/dve/parser/file_handling/service.py +++ b/src/dve/parser/file_handling/service.py @@ -4,7 +4,7 @@ """ -# pylint: disable=logging-not-lazy +# pylint: disable=logging-not-lazy, unidiomatic-typecheck, protected-access import hashlib import platform import shutil @@ -286,6 +286,7 @@ def create_directory(target_uri: URI): return +# pylint: disable=too-many-branches def _transfer_prefix( source_prefix: URI, target_prefix: URI, overwrite: bool, action: Literal["copy", "move"] ): @@ -296,26 +297,37 @@ def _transfer_prefix( if action not in ("move", "copy"): # pragma: no cover raise ValueError(f"Unsupported action {action!r}, expected one of: 'copy', 'move'") - if not source_prefix.endswith("/"): - source_prefix += "/" - if not target_prefix.endswith("/"): - target_prefix += "/" - source_uris: list[URI] = [] target_uris: list[URI] = [] source_impl = _get_implementation(source_prefix) target_impl = _get_implementation(target_prefix) + if type(source_impl) == LocalFilesystemImplementation: + source_prefix = source_impl._path_to_uri(source_impl._uri_to_path(source_prefix)) + + if type(target_impl) == LocalFilesystemImplementation: + target_prefix = target_impl._path_to_uri(target_impl._uri_to_path(target_prefix)) + + if not source_prefix.endswith("/"): + source_prefix += "/" + if not target_prefix.endswith("/"): + target_prefix += "/" + for source_uri, node_type in source_impl.iter_prefix(source_prefix, True): if node_type != "resource": continue if not source_uri.startswith(source_prefix): # pragma: no cover - raise FileAccessError( - f"Listed URI ({source_uri!r}) not relative to source prefix " - + f"({source_prefix!r})" - ) + if type(_get_implementation(source_uri)) == LocalFilesystemImplementation: + # Due to local file systems having issues with local file scheme, + # stripping this check off + pass + else: + raise FileAccessError( + f"Listed URI ({source_uri!r}) not relative to 
source prefix " + + f"({source_prefix!r})" + ) path_within_prefix = source_uri[len(source_prefix) :] target_uri = target_prefix + path_within_prefix @@ -359,11 +371,11 @@ def move_prefix(source_prefix: URI, target_prefix: URI, overwrite: bool = False) def resolve_location(filename_or_url: Location) -> URI: """Resolve a union of filename and URI to a URI.""" if isinstance(filename_or_url, Path): - return filename_or_url.expanduser().resolve().as_uri() + return filename_or_url.expanduser().resolve().as_posix() parsed_url = urlparse(filename_or_url) if parsed_url.scheme == "file": # Passed a URL as a file. - return file_uri_to_local_path(filename_or_url).as_uri() + return file_uri_to_local_path(filename_or_url).as_posix() if platform.system() != "Windows": # On Linux, a filesystem path will never present with a scheme. diff --git a/src/dve/pipeline/foundry_ddb_pipeline.py b/src/dve/pipeline/foundry_ddb_pipeline.py index 45d2261..7a597cd 100644 --- a/src/dve/pipeline/foundry_ddb_pipeline.py +++ b/src/dve/pipeline/foundry_ddb_pipeline.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Optional +from dve.common.error_utils import dump_processing_errors from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( duckdb_get_entity_count, duckdb_write_parquet, @@ -17,7 +18,6 @@ from dve.parser.file_handling.service import _get_implementation from dve.pipeline.duckdb_pipeline import DDBDVEPipeline from dve.pipeline.utils import SubmissionStatus -from dve.reporting.utils import dump_processing_errors @duckdb_get_entity_count @@ -75,7 +75,7 @@ def apply_data_contract( self._logger.exception("Apply data contract raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), - "contract", + "data_contract", [CriticalProcessingError.from_exception(exc)], ) self._audit_tables.mark_failed(submissions=[submission_info.submission_id]) @@ -149,12 +149,12 @@ def run_pipeline( if sub_stats: self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats]) except Exception as err: # pylint: disable=W0718 - self._logger.error( - f"During processing of submission_id: {sub_id}, this exception was raised: {err}" + self._logger.exception( + f"During processing of submission_id: {sub_id}, this exception was raised:" ) dump_processing_errors( fh.joinuri(self.processed_files_path, submission_info.submission_id), - "run_pipeline", + "pipeline", [CriticalProcessingError.from_exception(err)], ) self._audit_tables.mark_failed(submissions=[sub_id]) diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index 366b590..6aa4f41 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -16,6 +16,12 @@ from pydantic import validate_arguments import dve.reporting.excel_report as er +from dve.common.error_utils import ( + dump_feedback_errors, + dump_processing_errors, + get_feedback_errors_uri, + load_feedback_messages, +) from dve.core_engine.backends.base.auditing import BaseAuditingManager from dve.core_engine.backends.base.contract import BaseDataContract from dve.core_engine.backends.base.core import EntityManager @@ -29,13 +35,12 @@ from dve.core_engine.loggers import get_logger from dve.core_engine.message import FeedbackMessage from dve.core_engine.models import SubmissionInfo, SubmissionStatisticsRecord -from dve.core_engine.type_hints import URI, FileURI, InfoURI +from dve.core_engine.type_hints import URI, DVEStageName, FileURI, InfoURI from dve.parser import file_handling as fh from 
dve.parser.file_handling.implementations.file import LocalFilesystemImplementation from dve.parser.file_handling.service import _get_implementation from dve.pipeline.utils import SubmissionStatus, deadletter_file, load_config, load_reader from dve.reporting.error_report import ERROR_SCHEMA, calculate_aggregates -from dve.reporting.utils import dump_feedback_errors, dump_processing_errors PERMISSIBLE_EXCEPTIONS: tuple[type[Exception]] = ( FileNotFoundError, # type: ignore @@ -108,7 +113,9 @@ def get_entity_count(entity: EntityType) -> int: """Get a row count of an entity stored as parquet""" raise NotImplementedError() - def get_submission_status(self, step_name: str, submission_id: str) -> SubmissionStatus: + def get_submission_status( + self, step_name: DVEStageName, submission_id: str + ) -> SubmissionStatus: """Determine submission status of a submission if not explicitly given""" if not (submission_status := self._audit_tables.get_submission_status(submission_id)): self._logger.warning( @@ -403,7 +410,7 @@ def apply_data_contract( self._logger.info(f"Applying data contract to {submission_info.submission_id}") if not submission_status: submission_status = self.get_submission_status( - "contract", submission_info.submission_id + "data_contract", submission_info.submission_id ) if not self.processed_files_path: raise AttributeError("processed files path not provided") @@ -411,35 +418,36 @@ def apply_data_contract( if not self.rules_path: raise AttributeError("rules path not provided") - read_from = fh.joinuri( - self.processed_files_path, submission_info.submission_id, "transform/" - ) - write_to = fh.joinuri(self.processed_files_path, submission_info.submission_id, "contract/") + working_dir = fh.joinuri(self.processed_files_path, submission_info.submission_id) + + read_from = fh.joinuri(working_dir, "transform/") + write_to = fh.joinuri(working_dir, "data_contract/") + + fh.create_directory(write_to) # simply for local file systems _, config, model_config = load_config(submission_info.dataset_id, self.rules_path) entities = {} + entity_locations = {} for path, _ in fh.iter_prefix(read_from): + entity_locations[fh.get_file_name(path)] = path entities[fh.get_file_name(path)] = self.data_contract.read_parquet(path) - entities, messages, _success = self.data_contract.apply_data_contract( # type: ignore - entities, config.get_contract_metadata() + key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} + + entities, feedback_errors_uri, _success = self.data_contract.apply_data_contract( # type: ignore + working_dir, entities, entity_locations, config.get_contract_metadata(), key_fields ) entitity: self.data_contract.__entity_type__ # type: ignore for entity_name, entitity in entities.items(): self.data_contract.write_parquet(entitity, fh.joinuri(write_to, entity_name)) - key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} - if messages: - dump_feedback_errors( - fh.joinuri(self.processed_files_path, submission_info.submission_id), - "contract", - messages, - key_fields=key_fields, - ) + validation_failed: bool = False + if fh.get_resource_exists(feedback_errors_uri): + messages = load_feedback_messages(feedback_errors_uri) - validation_failed = any(not rule_message.is_informational for rule_message in messages) + validation_failed = any(not user_message.is_informational for user_message in messages) if validation_failed: submission_status.validation_failed = True @@ -463,7 +471,7 @@ def data_contract_step( sub_status = ( sub_status 
if sub_status - else self.get_submission_status("contract", info.submission_id) + else self.get_submission_status("data_contract", info.submission_id) ) dc_futures.append( ( @@ -490,7 +498,7 @@ def data_contract_step( self._logger.exception("Data Contract raised exception:") dump_processing_errors( fh.joinuri(self.processed_files_path, sub_info.submission_id), - "contract", + "data_contract", [CriticalProcessingError.from_exception(exc)], ) sub_status.processing_failed = True @@ -518,7 +526,7 @@ def data_contract_step( def apply_business_rules( self, submission_info: SubmissionInfo, submission_status: Optional[SubmissionStatus] = None - ): + ) -> tuple[SubmissionInfo, SubmissionStatus]: """Apply the business rules to a given submission, the submission may have failed at the data_contract step so this should be passed in as a bool """ @@ -541,11 +549,16 @@ def apply_business_rules( raise AttributeError("step implementations has not been provided.") _, config, model_config = load_config(submission_info.dataset_id, self.rules_path) + working_directory: URI = fh.joinuri( + self._processed_files_path, submission_info.submission_id + ) ref_data = config.get_reference_data_config() rules = config.get_rule_metadata() reference_data = self._reference_data_loader(ref_data) # type: ignore entities = {} - contract = fh.joinuri(self.processed_files_path, submission_info.submission_id, "contract") + contract = fh.joinuri( + self.processed_files_path, submission_info.submission_id, "data_contract" + ) for parquet_uri, _ in fh.iter_prefix(contract): file_name = fh.get_file_name(parquet_uri) @@ -562,17 +575,13 @@ def apply_business_rules( entity_manager = EntityManager(entities=entities, reference_data=reference_data) - rule_messages = self.step_implementations.apply_rules(entity_manager, rules) # type: ignore key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} - if rule_messages: - dump_feedback_errors( - fh.joinuri(self.processed_files_path, submission_info.submission_id), - "business_rules", - rule_messages, - key_fields, - ) + self.step_implementations.apply_rules(working_directory, entity_manager, rules, key_fields) # type: ignore + rule_messages = load_feedback_messages( + get_feedback_errors_uri(working_directory, "business_rules") + ) submission_status.validation_failed = ( any(not rule_message.is_informational for rule_message in rule_messages) or submission_status.validation_failed @@ -707,12 +716,12 @@ def _get_error_dataframes(self, submission_id: str): errors_dfs = [pl.DataFrame([], schema=ERROR_SCHEMA)] # type: ignore for file, _ in fh.iter_prefix(path): - if fh.get_file_suffix(file) != "json": + if fh.get_file_suffix(file) != "jsonl": continue with fh.open_stream(file) as f: errors = None try: - errors = json.load(f) + errors = [json.loads(err) for err in f.readlines()] except UnicodeDecodeError: self._logger.exception(f"Error reading file: {file}") continue diff --git a/src/dve/reporting/error_report.py b/src/dve/reporting/error_report.py index ba4a4ac..8852fcb 100644 --- a/src/dve/reporting/error_report.py +++ b/src/dve/reporting/error_report.py @@ -1,15 +1,14 @@ """Error report generation""" -import datetime as dt import json from collections import deque from functools import partial from multiprocessing import Pool, cpu_count -from typing import Union import polars as pl from polars import DataFrame, LazyFrame, Utf8, col # type: ignore +from dve.common.error_utils import conditional_cast from dve.core_engine.message import FeedbackMessage from 
dve.parser.file_handling.service import open_stream @@ -49,22 +48,6 @@ def get_error_codes(error_code_path: str) -> LazyFrame: return pl.DataFrame(df_lists).lazy() # type: ignore -def conditional_cast(value, primary_keys: list[str], value_separator: str) -> Union[list[str], str]: - """Determines what to do with a value coming back from the error list""" - if isinstance(value, list): - casts = [ - conditional_cast(val, primary_keys, value_separator) for val in value - ] # type: ignore - return value_separator.join( - [f"{pk}: {id}" if pk else "" for pk, id in zip(primary_keys, casts)] - ) - if isinstance(value, dt.date): - return value.isoformat() - if isinstance(value, dict): - return "" - return str(value) - - def _convert_inner_dict(error: FeedbackMessage, key_fields): return { key: ( diff --git a/src/dve/reporting/utils.py b/src/dve/reporting/utils.py deleted file mode 100644 index 1cb0b45..0000000 --- a/src/dve/reporting/utils.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Utilities to support reporting""" - -import json -from typing import Optional - -import dve.parser.file_handling as fh -from dve.core_engine.exceptions import CriticalProcessingError -from dve.core_engine.type_hints import URI, Messages -from dve.reporting.error_report import conditional_cast - - -def dump_feedback_errors( - working_folder: URI, - step_name: str, - messages: Messages, - key_fields: Optional[dict[str, list[str]]] = None, -): - """Write out captured feedback error messages.""" - if not working_folder: - raise AttributeError("processed files path not passed") - - if not key_fields: - key_fields = {} - - errors = fh.joinuri(working_folder, "errors", f"{step_name}_errors.json") - processed = [] - - for message in messages: - if message.original_entity is not None: - primary_keys = key_fields.get(message.original_entity, []) - elif message.entity is not None: - primary_keys = key_fields.get(message.entity, []) - else: - primary_keys = [] - - error = message.to_dict( - key_field=primary_keys, - value_separator=" -- ", - max_number_of_values=10, - record_converter=None, - ) - error["Key"] = conditional_cast(error["Key"], primary_keys, value_separator=" -- ") - processed.append(error) - - with fh.open_stream(errors, "a") as f: - json.dump( - processed, - f, - default=str, - ) - - -def dump_processing_errors( - working_folder: URI, step_name: str, errors: list[CriticalProcessingError] -): - """Write out critical processing errors""" - if not working_folder: - raise AttributeError("processed files path not passed") - if not step_name: - raise AttributeError("step name not passed") - if not errors: - raise AttributeError("errors list not passed") - - error_file: URI = fh.joinuri(working_folder, "processing_errors", "processing_errors.json") - processed = [] - - for error in errors: - processed.append( - { - "step_name": step_name, - "error_location": "processing", - "error_level": "integrity", - "error_message": error.error_message, - "error_traceback": error.messages, - } - ) - - with fh.open_stream(error_file, "a") as f: - json.dump( - processed, - f, - default=str, - ) diff --git a/tests/features/steps/utilities.py b/tests/features/steps/utilities.py index 30ebe16..aa9adc1 100644 --- a/tests/features/steps/utilities.py +++ b/tests/features/steps/utilities.py @@ -28,7 +28,6 @@ ] SERVICE_TO_STORAGE_PATH_MAPPING: Dict[str, str] = { "file_transformation": "transform", - "data_contract": "contract", } @@ -49,12 +48,12 @@ def load_errors_from_service(processing_folder: Path, service: str) -> pl.DataFr err_location = Path( 
processing_folder, "errors", - f"{SERVICE_TO_STORAGE_PATH_MAPPING.get(service, service)}_errors.json", + f"{SERVICE_TO_STORAGE_PATH_MAPPING.get(service, service)}_errors.jsonl", ) msgs = [] try: with open(err_location) as errs: - msgs = json.load(errs) + msgs = [json.loads(err) for err in errs.readlines()] except FileNotFoundError: pass diff --git a/tests/fixtures.py b/tests/fixtures.py index 8a9a147..457cf42 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -113,12 +113,11 @@ def spark_test_database(spark: SparkSession) -> Iterator[str]: -@pytest.fixture() +@pytest.fixture(scope="function") def temp_ddb_conn() -> Iterator[Tuple[Path, DuckDBPyConnection]]: """Temp DuckDB directory for the database""" - db = uuid4().hex + db = f"dve_{uuid4().hex}" with tempfile.TemporaryDirectory(prefix="ddb_audit_testing") as tmp: db_file = Path(tmp, db + ".duckdb") conn = connect(database=db_file, read_only=False) - yield db_file, conn diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py index 61920c2..442e4a0 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py @@ -14,8 +14,12 @@ from dve.core_engine.backends.implementations.duckdb.readers.xml import DuckDBXMLStreamReader from dve.core_engine.backends.metadata.contract import DataContractMetadata, ReaderConfig from dve.core_engine.backends.utilities import stringify_model +from dve.core_engine.message import UserMessage from dve.core_engine.type_hints import URI from dve.core_engine.validation import RowValidator +from dve.parser.file_handling import get_resource_exists, joinuri +from dve.parser.file_handling.service import get_parent +from dve.common.error_utils import load_feedback_messages from tests.test_core_engine.test_backends.fixtures import ( nested_all_string_parquet, simple_all_string_parquet, @@ -83,15 +87,16 @@ def test_duckdb_data_contract_csv(temp_csv_file): header=True, delim=",", connection=connection ).read_to_entity_type(DuckDBPyRelation, str(uri), "test_ds", stringify_model(mdl)) } + entity_locations: Dict[str, URI] = {"test_ds": str(uri)} data_contract: DuckDBDataContract = DuckDBDataContract(connection) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta) rel: DuckDBPyRelation = entities.get("test_ds") assert dict(zip(rel.columns, rel.dtypes)) == { fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) for fld in mdl.__fields__.values() } - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert stage_successful @@ -150,6 +155,11 @@ def test_duckdb_data_contract_xml(temp_xml_file): ddb_connection=connection, root_tag="root", record_tag="ClassData" ).read_to_relation(str(uri), "class_info", class_model), } + entity_locations: dict[str, URI] = {} + for entity, rel in entities.items(): + loc: URI = joinuri(get_parent(uri.as_posix()), f"{entity}.parquet") + rel.write_parquet(loc, compression="snappy") + entity_locations[entity] = loc dc_meta = DataContractMetadata( reader_metadata={ @@ -178,7 +188,7 @@ def test_duckdb_data_contract_xml(temp_xml_file): ) data_contract: DuckDBDataContract = 
DuckDBDataContract(connection) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta) header_rel: DuckDBPyRelation = entities.get("test_header") header_expected_schema: Dict[str, DuckDBPyType] = { fld.name: get_duckdb_type_from_annotation(fld.type_) @@ -189,7 +199,7 @@ def test_duckdb_data_contract_xml(temp_xml_file): for fld in class_model.__fields__.values() } class_data_rel: DuckDBPyRelation = entities.get("test_class_info") - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert header_rel.count("*").fetchone()[0] == 1 assert dict(zip(header_rel.columns, header_rel.dtypes)) == header_expected_schema assert class_data_rel.count("*").fetchone()[0] == 2 @@ -237,9 +247,9 @@ def test_ddb_data_contract_read_and_write_basic_parquet( reporting_fields={"simple_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"simple_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["simple_model"].count("*").fetchone()[0] == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("simple_model_output.parquet") @@ -296,9 +306,9 @@ def test_ddb_data_contract_read_nested_parquet(nested_all_string_parquet): reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["nested_model"].count("*").fetchone()[0] == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("nested_model_output.parquet") @@ -353,12 +363,13 @@ def test_duckdb_data_contract_custom_error_details(nested_all_string_parquet_w_e reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful + messages: list[UserMessage] = [msg for msg in load_feedback_messages(feedback_errors_uri)] assert len(messages) == 2 - messages = sorted(messages, key= lambda x: x.error_code) - assert messages[0].error_code == "SUBFIELDTESTIDBAD" - assert messages[0].error_message == "subfield id is invalid: subfield.id - WRONG" - assert messages[1].error_code == "TESTIDBAD" - assert messages[1].error_message == "id is invalid: id - WRONG" - assert messages[1].entity == "test_rename" \ No newline at end of file + messages = sorted(messages, key= lambda x: x.ErrorCode) + assert messages[0].ErrorCode == "SUBFIELDTESTIDBAD" + assert messages[0].ErrorMessage == "subfield id is invalid: subfield.id - WRONG" + assert messages[1].ErrorCode == "TESTIDBAD" + assert messages[1].ErrorMessage == "id is invalid: id - WRONG" + assert messages[1].Entity == "test_rename" \ No newline at end of file diff --git 
a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py new file mode 100644 index 0000000..7ae4858 --- /dev/null +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py @@ -0,0 +1,128 @@ +from pathlib import Path +import shutil +from uuid import uuid4 + +import pytest +from dve.core_engine.backends.exceptions import MissingRefDataEntity +from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader +from dve.core_engine.backends.base.core import EntityManager +from dve.core_engine.backends.base.reference_data import ReferenceFile, ReferenceTable, ReferenceURI + +from tempfile import TemporaryDirectory + +from tests.conftest import get_test_file_path + +@pytest.fixture(scope="module") +def temp_working_dir(): + with TemporaryDirectory(prefix="refdata_test") as tmp: + refdata_path = get_test_file_path("movies/refdata") + shutil.copytree(refdata_path.as_posix(), tmp, dirs_exist_ok=True) + yield tmp + +@pytest.fixture(scope="function") +def ddb_refdata_loader(temp_working_dir, temp_ddb_conn): + _, conn = temp_ddb_conn + DuckDBRefDataLoader.connection = conn + DuckDBRefDataLoader.dataset_config_uri = temp_working_dir + yield DuckDBRefDataLoader, temp_working_dir + +@pytest.fixture(scope="function") +def ddb_refdata_table(ddb_refdata_loader): + refdata_loader, _ = ddb_refdata_loader + schema = "dve_" + uuid4().hex + tbl = "movies_sequels" + refdata_loader.connection.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}") + refdata_loader.connection.read_parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).to_table(f"{schema}.{tbl}") + yield schema, tbl + refdata_loader.connection.sql(f"DROP TABLE IF EXISTS {schema}.{tbl}") + refdata_loader.connection.sql(f"DROP SCHEMA IF EXISTS {schema}") + +def test_load_arrow_file(ddb_refdata_loader): + refdata_loader, _ = ddb_refdata_loader + config = { + "test_refdata": ReferenceFile(type="filename", + filename="./movies_sequels.arrow") + } + duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + + test = duckdb_refdata_loader.load_file(config.get("test_refdata")) + + assert test.shape == (3, 3) + +def test_load_parquet_file(ddb_refdata_loader): + refdata_loader, _ = ddb_refdata_loader + config = { + "test_refdata": ReferenceFile(type="filename", + filename="./movies_sequels.parquet") + } + duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + + test = duckdb_refdata_loader.load_file(config.get("test_refdata")) + + assert test.shape == (2, 3) + +def test_load_uri_parquet(ddb_refdata_loader): + refdata_dir: Path + refdata_loader, refdata_dir = ddb_refdata_loader + config = { + "test_refdata": ReferenceURI(type="uri", + uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()) + } + duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + + test = duckdb_refdata_loader.load_uri(config.get("test_refdata")) + + assert test.shape == (2, 3) + +def test_load_uri_arrow(ddb_refdata_loader): + refdata_loader, refdata_dir = ddb_refdata_loader + config = { + "test_refdata": ReferenceURI(type="uri", + uri=Path(refdata_dir).joinpath("movies_sequels.arrow").as_posix()) + } + duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + + test = duckdb_refdata_loader.load_uri(config.get("test_refdata")) + + assert test.shape == (3, 3) + +def test_table_read(ddb_refdata_loader, ddb_refdata_table): + 
refdata_loader, _ = ddb_refdata_loader + db, tbl = ddb_refdata_table + config = { + "test_refdata": ReferenceTable(type="table", + table_name=tbl, + database=db) + } + duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + + test = duckdb_refdata_loader.load_table(config.get("test_refdata")) + + assert test.shape == (2, 3) + +def test_via_entity_manager(ddb_refdata_loader, ddb_refdata_table): + refdata_loader, refdata_dir = ddb_refdata_loader + db, tbl = ddb_refdata_table + config = { + "test_refdata_file": ReferenceFile(type="filename", + filename="./movies_sequels.arrow"), + "test_refdata_uri": ReferenceURI(type="uri", + uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()), + "test_refdata_table": ReferenceTable(type="table", + table_name=tbl, + database=db) + } + em = EntityManager({}, reference_data=refdata_loader(config)) + assert em.get("refdata_test_refdata_file").shape == (3, 3) + assert em.get("refdata_test_refdata_uri").shape == (2, 3) + assert em.get("refdata_test_refdata_table").shape == (2, 3) + +def test_refdata_error(ddb_refdata_loader): + refdata_loader, refdata_dir = ddb_refdata_loader + config = { + "test_refdata_file": ReferenceFile(type="filename", + filename="./movies_sequels.arrow") + } + duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + with pytest.raises(MissingRefDataEntity): + duckdb_refdata_loader["missing_refdata"] \ No newline at end of file diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py index 789ca1a..921c9be 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py @@ -16,8 +16,11 @@ from dve.core_engine.backends.implementations.spark.contract import SparkDataContract from dve.core_engine.backends.metadata.contract import DataContractMetadata, ReaderConfig +from dve.core_engine.message import UserMessage from dve.core_engine.type_hints import URI from dve.core_engine.validation import RowValidator +from dve.parser.file_handling.service import get_parent, get_resource_exists +from dve.common.error_utils import load_feedback_messages from tests.test_core_engine.test_backends.fixtures import ( nested_all_string_parquet, nested_all_string_parquet_w_errors, @@ -67,9 +70,9 @@ def test_spark_data_contract_read_and_write_basic_parquet( reporting_fields={"simple_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"simple_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) == 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["simple_model"].count() == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("simple_model_output.parquet") @@ -140,9 +143,9 @@ def test_spark_data_contract_read_nested_parquet(nested_all_string_parquet): reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful - assert len(messages) 
== 0 + assert not get_resource_exists(feedback_errors_uri) assert entities["nested_model"].count() == 2 # check writes entity to parquet output_path: Path = Path(parquet_uri).parent.joinpath("nested_model_output.parquet") @@ -227,14 +230,15 @@ def test_spark_data_contract_custom_error_details(nested_all_string_parquet_w_er reporting_fields={"nested_model": ["id"]}, ) - entities, messages, stage_successful = data_contract.apply_data_contract(entities, dc_meta) + entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(parquet_uri), entities, {"nested_model": parquet_uri}, dc_meta) assert stage_successful + messages: list[UserMessage] = [msg for msg in load_feedback_messages(feedback_errors_uri)] assert len(messages) == 2 - messages = sorted(messages, key= lambda x: x.error_code) - assert messages[0].error_code == "SUBFIELDTESTIDBAD" - assert messages[0].error_message == "subfield id is invalid: subfield.id - WRONG" - assert messages[1].error_code == "TESTIDBAD" - assert messages[1].error_message == "id is invalid: id - WRONG" - assert messages[1].entity == "test_rename" + messages = sorted(messages, key= lambda x: x.ErrorCode) + assert messages[0].ErrorCode == "SUBFIELDTESTIDBAD" + assert messages[0].ErrorMessage == "subfield id is invalid: subfield.id - WRONG" + assert messages[1].ErrorCode == "TESTIDBAD" + assert messages[1].ErrorMessage == "id is invalid: id - WRONG" + assert messages[1].Entity == "test_rename" \ No newline at end of file diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py new file mode 100644 index 0000000..b50b9bb --- /dev/null +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py @@ -0,0 +1,102 @@ +from pathlib import Path +import shutil + +import pytest +from dve.core_engine.backends.exceptions import MissingRefDataEntity, RefdataLacksFileExtensionSupport +from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader +from dve.core_engine.backends.base.core import EntityManager +from dve.core_engine.backends.base.reference_data import ReferenceFile, ReferenceTable, ReferenceURI + +from tempfile import TemporaryDirectory + +from tests.conftest import get_test_file_path + +@pytest.fixture(scope="module") +def temp_working_dir(): + with TemporaryDirectory(prefix="refdata_test") as tmp: + refdata_path = get_test_file_path("movies/refdata") + shutil.copytree(refdata_path.as_posix(), tmp, dirs_exist_ok=True) + yield tmp + +@pytest.fixture(scope="function") +def spark_refdata_loader(spark, temp_working_dir): + SparkRefDataLoader.spark = spark + SparkRefDataLoader.dataset_config_uri = temp_working_dir + yield SparkRefDataLoader, temp_working_dir + +@pytest.fixture(scope="function") +def spark_refdata_table(spark_refdata_loader, spark_test_database): + refdata_loader, _ = spark_refdata_loader + tbl = "movies_sequels" + refdata_loader.spark.read.parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).write.saveAsTable(f"{spark_test_database}.{tbl}") + yield spark_test_database, tbl + refdata_loader.spark.sql(f"DROP TABLE IF EXISTS {spark_test_database}.{tbl}") + + +def test_load_parquet_file(spark_refdata_loader): + refdata_loader, _ = spark_refdata_loader + config = { + "test_refdata": ReferenceFile(type="filename", + filename="./movies_sequels.parquet") + } + spk_refdata_loader: SparkRefDataLoader = 
refdata_loader(config) + + test = spk_refdata_loader.load_file(config.get("test_refdata")) + + assert test.count() == 2 + +def test_load_uri_parquet(spark_refdata_loader): + refdata_dir: Path + refdata_loader, refdata_dir = spark_refdata_loader + config = { + "test_refdata": ReferenceURI(type="uri", + uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()) + } + spk_refdata_loader: SparkRefDataLoader = refdata_loader(config) + + test = spk_refdata_loader.load_uri(config.get("test_refdata")) + + assert test.count() == 2 + +def test_table_read(spark_refdata_loader, spark_refdata_table): + refdata_loader, _ = spark_refdata_loader + db, tbl = spark_refdata_table + config = { + "test_refdata": ReferenceTable(type="table", + table_name=tbl, + database=db) + } + spk_refdata_loader: SparkRefDataLoader = refdata_loader(config) + + test = spk_refdata_loader.load_table(config.get("test_refdata")) + + assert test.count() == 2 + +def test_via_entity_manager(spark_refdata_loader, spark_refdata_table): + refdata_loader, refdata_dir = spark_refdata_loader + db, tbl = spark_refdata_table + config = { + "test_refdata_file": ReferenceFile(type="filename", + filename="./movies_sequels.parquet"), + "test_refdata_uri": ReferenceURI(type="uri", + uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()), + "test_refdata_table": ReferenceTable(type="table", + table_name=tbl, + database=db) + } + em = EntityManager({}, reference_data=refdata_loader(config)) + assert em.get("refdata_test_refdata_file").count() == 2 + assert em.get("refdata_test_refdata_uri").count() == 2 + assert em.get("refdata_test_refdata_table").count() == 2 + +def test_refdata_error(spark_refdata_loader): + refdata_loader, _ = spark_refdata_loader + config = { + "test_refdata_file": ReferenceFile(type="filename", + filename="./movies_sequels.arrow") + } + em = EntityManager({}, reference_data=refdata_loader(config)) + with pytest.raises(MissingRefDataEntity): + em["refdata_missing"] + em["refdata_test_refdata_file"] + \ No newline at end of file diff --git a/tests/test_core_engine/test_engine.py b/tests/test_core_engine/test_engine.py index 5e16f09..5118cbd 100644 --- a/tests/test_core_engine/test_engine.py +++ b/tests/test_core_engine/test_engine.py @@ -9,6 +9,7 @@ import pytest from pyspark.sql import SparkSession +from dve.common.error_utils import load_all_error_messages from dve.core_engine.backends.implementations.spark.backend import SparkBackend from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.engine import CoreEngine @@ -25,21 +26,21 @@ def test_dummy_planet_run(self, spark: SparkSession, temp_dir: str): with warnings.catch_warnings(): warnings.simplefilter("ignore") test_instance = CoreEngine.build( - dataset_config_path=config_path.as_uri(), + dataset_config_path=config_path.as_posix(), output_prefix=Path(temp_dir), - backend=SparkBackend(dataset_config_uri=config_path.parent.as_uri(), + backend=SparkBackend(dataset_config_uri=config_path.parent.as_posix(), spark_session=spark, reference_data_loader=refdata_loader) ) with test_instance: - _, messages = test_instance.run_pipeline( + _, errors_uri = test_instance.run_pipeline( entity_locations={ - "planets": get_test_file_path("planets/planets_demo.csv").as_uri(), + "planets": get_test_file_path("planets/planets_demo.csv").as_posix(), }, ) - critical_messages = [message for message in messages if message.is_critical] + critical_messages = [message for message in load_all_error_messages(errors_uri) 
if message.is_critical] assert not critical_messages output_files = Path(temp_dir).iterdir() @@ -55,7 +56,7 @@ def test_dummy_planet_run(self, spark: SparkSession, temp_dir: str): def test_dummy_demographics_run(self, spark, temp_dir: str): """Test that we can still run the test example with the dummy demographics data.""" - config_path = get_test_file_path("demographics/basic_demographics.dischema.json").as_uri() + config_path = get_test_file_path("demographics/basic_demographics.dischema.json").as_posix() with warnings.catch_warnings(): warnings.simplefilter("ignore") test_instance = CoreEngine.build( @@ -64,15 +65,15 @@ def test_dummy_demographics_run(self, spark, temp_dir: str): ) with test_instance: - _, messages = test_instance.run_pipeline( + _, errors_uri = test_instance.run_pipeline( entity_locations={ "demographics": get_test_file_path( "demographics/basic_demographics.csv" - ).as_uri(), + ).as_posix(), }, ) - critical_messages = [message for message in messages if message.is_critical] + critical_messages = [message for message in load_all_error_messages(errors_uri) if message.is_critical] assert not critical_messages output_files = Path(temp_dir).iterdir() @@ -88,7 +89,7 @@ def test_dummy_demographics_run(self, spark, temp_dir: str): def test_dummy_books_run(self, spark, temp_dir: str): """Test that we can handle files with more complex nested schemas.""" - config_path = get_test_file_path("books/nested_books.dischema.json").as_uri() + config_path = get_test_file_path("books/nested_books.dischema.json").as_posix() with warnings.catch_warnings(): warnings.simplefilter("ignore") test_instance = CoreEngine.build( @@ -96,14 +97,14 @@ def test_dummy_books_run(self, spark, temp_dir: str): output_prefix=Path(temp_dir), ) with test_instance: - _, messages = test_instance.run_pipeline( + _, errors_uri = test_instance.run_pipeline( entity_locations={ - "header": get_test_file_path("books/nested_books.xml").as_uri(), - "nested_books": get_test_file_path("books/nested_books.xml").as_uri(), + "header": get_test_file_path("books/nested_books.xml").as_posix(), + "nested_books": get_test_file_path("books/nested_books.xml").as_posix(), } ) - critical_messages = [message for message in messages if message.is_critical] + critical_messages = [message for message in load_all_error_messages(errors_uri) if message.is_critical] assert not critical_messages output_files = Path(temp_dir).iterdir() diff --git a/tests/test_parser/test_file_handling.py b/tests/test_parser/test_file_handling.py index cfa90be..8833d17 100644 --- a/tests/test_parser/test_file_handling.py +++ b/tests/test_parser/test_file_handling.py @@ -32,6 +32,7 @@ resolve_location, ) from dve.parser.file_handling.implementations import S3FilesystemImplementation +from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation, file_uri_to_local_path from dve.parser.file_handling.service import _get_implementation from dve.parser.type_hints import Hostname, Scheme, URIPath @@ -192,11 +193,14 @@ def test_iter_prefix(self, prefix: str): file.write("") actual_nodes = sorted(iter_prefix(prefix, recursive=False)) + cleaned_prefix = ( + file_uri_to_local_path(prefix).as_posix() if type(_get_implementation(prefix)) == LocalFilesystemImplementation + else prefix) expected_nodes = sorted( [ - (prefix + "/test_file.txt", "resource"), - (prefix + "/test_sibling.txt", "resource"), - (prefix + "/test_prefix/", "directory"), + (cleaned_prefix + "/test_file.txt", "resource"), + (cleaned_prefix + "/test_sibling.txt", "resource"), + 
(cleaned_prefix + "/test_prefix/", "directory"), ] ) assert actual_nodes == expected_nodes @@ -221,14 +225,20 @@ def test_iter_prefix_recursive(self, prefix: str): "/test_prefix/another_level/test_file.txt", "/test_prefix/another_level/nested_sibling.txt", ] - resource_uris = [prefix + uri_path for uri_path in structure] - directory_uris = [prefix + "/test_prefix/", prefix + "/test_prefix/another_level/"] + cleaned_prefix = ( + file_uri_to_local_path(prefix).as_posix() + if type(_get_implementation(prefix)) == LocalFilesystemImplementation + else prefix + ) + + resource_uris = [cleaned_prefix + uri_path for uri_path in structure] + directory_uris = [cleaned_prefix + "/test_prefix/", cleaned_prefix + "/test_prefix/another_level/"] for uri in resource_uris: with open_stream(uri, "w") as file: file.write("") - actual_nodes = sorted(iter_prefix(prefix, recursive=True)) + actual_nodes = sorted(iter_prefix(cleaned_prefix, recursive=True)) expected_nodes = sorted( [(uri, "directory") for uri in directory_uris] + [(uri, "resource") for uri in resource_uris] @@ -409,21 +419,21 @@ def test_cursed_s3_keys_supported(temp_s3_prefix: str): [ ( Path("abc/samples/planet_test_records.xml"), - Path("abc/samples/planet_test_records.xml").resolve().as_uri(), + Path("abc/samples/planet_test_records.xml").resolve().as_posix(), ), - ("file:///home/user/file.txt", Path("/home/user/file.txt").as_uri()), + ("file:///home/user/file.txt", Path("/home/user/file.txt").as_posix()), ("s3://bucket/path/within/bucket/file.csv", "s3://bucket/path/within/bucket/file.csv"), ( "file:///abc/samples/planet_test_records.xml", - "file:///abc/samples/planet_test_records.xml", + "/abc/samples/planet_test_records.xml", ), ( "/abc/samples/planet_test_records.xml", - Path("/abc/samples/planet_test_records.xml").as_uri(), + Path("/abc/samples/planet_test_records.xml").as_posix(), ), ( Path("/abc/samples/planet_test_records.xml"), - Path("/abc/samples/planet_test_records.xml").as_uri(), + Path("/abc/samples/planet_test_records.xml").as_posix(), ), ], # fmt: on diff --git a/tests/test_pipeline/pipeline_helpers.py b/tests/test_pipeline/pipeline_helpers.py index 1518ccf..ddd4ef8 100644 --- a/tests/test_pipeline/pipeline_helpers.py +++ b/tests/test_pipeline/pipeline_helpers.py @@ -172,7 +172,7 @@ def planets_data_after_data_contract() -> Iterator[Tuple[SubmissionInfo, str]]: dataset_id="planets", file_extension="json", ) - output_path = Path(tdir, submitted_file_info.submission_id, "contract", "planets") + output_path = Path(tdir, submitted_file_info.submission_id, "data_contract", "planets") output_path.mkdir(parents=True) planet_contract_data = { @@ -220,7 +220,7 @@ def planets_data_after_data_contract_that_break_business_rules() -> Iterator[ dataset_id="planets", file_extension="json", ) - output_path = Path(tdir, submitted_file_info.submission_id, "contract", "planets") + output_path = Path(tdir, submitted_file_info.submission_id, "data_contract", "planets") output_path.mkdir(parents=True) planet_contract_data = { @@ -398,9 +398,10 @@ def error_data_after_business_rules() -> Iterator[Tuple[SubmissionInfo, str]]: } ]""" ) - output_file_path = output_path / "business_rules_errors.json" + output_file_path = output_path / "business_rules_errors.jsonl" with open(output_file_path, "w", encoding="utf-8") as f: - json.dump(error_data, f) + for entry in error_data: + f.write(json.dumps(entry) + "\n") yield submitted_file_info, tdir diff --git a/tests/test_pipeline/test_duckdb_pipeline.py b/tests/test_pipeline/test_duckdb_pipeline.py index 
58eb4ac..29e0734 100644 --- a/tests/test_pipeline/test_duckdb_pipeline.py +++ b/tests/test_pipeline/test_duckdb_pipeline.py @@ -132,7 +132,7 @@ def test_data_contract_step( assert len(success) == 1 assert not success[0][1].validation_failed assert len(failed) == 0 - assert Path(processed_file_path, sub_info.submission_id, "contract", "planets").exists() + assert Path(processed_file_path, sub_info.submission_id, "data_contract", "planets").exists() assert pl_row_count(audit_manager.get_all_business_rule_submissions().pl()) == 1 diff --git a/tests/test_pipeline/test_foundry_ddb_pipeline.py b/tests/test_pipeline/test_foundry_ddb_pipeline.py index 68fec99..12a7fd1 100644 --- a/tests/test_pipeline/test_foundry_ddb_pipeline.py +++ b/tests/test_pipeline/test_foundry_ddb_pipeline.py @@ -121,7 +121,7 @@ def test_foundry_runner_error(planet_test_files, temp_ddb_conn): processing_folder, sub_info.submission_id, "processing_errors", - "processing_errors.json" + "processing_errors.jsonl" ) assert perror_path.exists() perror_schema = { diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py index 7f4738f..910626a 100644 --- a/tests/test_pipeline/test_spark_pipeline.py +++ b/tests/test_pipeline/test_spark_pipeline.py @@ -16,10 +16,12 @@ import polars as pl from pyspark.sql import SparkSession +from dve.common.error_utils import load_feedback_messages from dve.core_engine.backends.base.auditing import FilterCriteria from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations +from dve.core_engine.message import UserMessage from dve.core_engine.models import ProcessingStatusRecord, SubmissionInfo, SubmissionStatisticsRecord import dve.parser.file_handling as fh from dve.pipeline.spark_pipeline import SparkDVEPipeline @@ -135,7 +137,7 @@ def test_apply_data_contract_success( assert not sub_status.validation_failed - assert Path(Path(processed_file_path), sub_info.submission_id, "contract", "planets").exists() + assert Path(Path(processed_file_path), sub_info.submission_id, "data_contract", "planets").exists() def test_apply_data_contract_failed( # pylint: disable=redefined-outer-name @@ -157,9 +159,9 @@ def test_apply_data_contract_failed( # pylint: disable=redefined-outer-name assert sub_status.validation_failed output_path = Path(processed_file_path) / sub_info.submission_id - assert Path(output_path, "contract", "planets").exists() + assert Path(output_path, "data_contract", "planets").exists() - errors_path = Path(output_path, "errors", "contract_errors.json") + errors_path = Path(output_path, "errors", "data_contract_errors.jsonl") assert errors_path.exists() expected_errors = [ @@ -203,10 +205,10 @@ def test_apply_data_contract_failed( # pylint: disable=redefined-outer-name "Category": "Bad value", }, ] - with open(errors_path, "r", encoding="utf-8") as f: - actual_errors = json.load(f) + + actual_errors = list(load_feedback_messages(errors_path.as_posix())) - assert actual_errors == expected_errors + assert actual_errors == [UserMessage(**err) for err in expected_errors] def test_data_contract_step( @@ -234,7 +236,7 @@ def test_data_contract_step( assert not success[0][1].validation_failed assert len(failed) == 0 - assert Path(processed_file_path, sub_info.submission_id, "contract", "planets").exists() + assert Path(processed_file_path, 
sub_info.submission_id, "data_contract", "planets").exists() assert audit_manager.get_all_business_rule_submissions().count() == 1 audit_result = audit_manager.get_all_error_report_submissions() @@ -329,7 +331,7 @@ def test_apply_business_rules_with_data_errors( # pylint: disable=redefined-out assert og_planets_entity_path.exists() assert spark.read.parquet(str(og_planets_entity_path)).count() == 1 - errors_path = Path(br_path.parent, "errors", "business_rules_errors.json") + errors_path = Path(br_path.parent, "errors", "business_rules_errors.jsonl") assert errors_path.exists() expected_errors = [ @@ -360,10 +362,10 @@ def test_apply_business_rules_with_data_errors( # pylint: disable=redefined-out "Category": "Bad value", }, ] - with open(errors_path, "r", encoding="utf-8") as f: - actual_errors = json.load(f) + + actual_errors = list(load_feedback_messages(errors_path.as_posix())) - assert actual_errors == expected_errors + assert actual_errors == [UserMessage(**err) for err in expected_errors] def test_business_rule_step( diff --git a/tests/test_reporting/test_utils.py b/tests/test_reporting/test_error_utils.py similarity index 92% rename from tests/test_reporting/test_utils.py rename to tests/test_reporting/test_error_utils.py index c240ca5..5d61bc7 100644 --- a/tests/test_reporting/test_utils.py +++ b/tests/test_reporting/test_error_utils.py @@ -6,7 +6,7 @@ import polars as pl from dve.core_engine.exceptions import CriticalProcessingError -from dve.reporting.utils import dump_processing_errors +from dve.common.error_utils import dump_processing_errors # pylint: disable=C0116 @@ -44,7 +44,7 @@ def test_dump_processing_errors(): perror_schema ) error_df = pl.read_json( - Path(output_path, "processing_errors.json") + Path(output_path, "processing_errors.jsonl") ) cols_to_check = ["step_name", "error_location", "error_level", "error_message"] diff --git a/tests/testdata/movies/refdata/movies_sequels.arrow b/tests/testdata/movies/refdata/movies_sequels.arrow new file mode 100644 index 0000000..89ec37f Binary files /dev/null and b/tests/testdata/movies/refdata/movies_sequels.arrow differ
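
A minimal usage sketch of the relocated `dve.common.error_utils` helpers exercised by the hunks above (per-stage `*_errors.jsonl` feedback files plus a `processing_errors.jsonl` for unrecoverable failures). The `working_dir` path and the forced exception are purely illustrative assumptions, not values from this changeset:

```python
from dve.common.error_utils import (
    dump_processing_errors,
    get_feedback_errors_uri,
    load_feedback_messages,
)
from dve.core_engine.exceptions import CriticalProcessingError

# Hypothetical per-submission working directory; in the pipeline steps this is
# fh.joinuri(processed_files_path, submission_info.submission_id).
working_dir = "/tmp/dve/processed/submission-0001"

# Each stage appends its feedback messages to <working_dir>/errors/<stage>_errors.jsonl.
errors_uri = get_feedback_errors_uri(working_dir, "data_contract")

# load_feedback_messages yields UserMessage records (or nothing if the file is absent),
# so materialise it before inspecting severity more than once.
messages = list(load_feedback_messages(errors_uri))
warnings = [msg for msg in messages if msg.is_informational]
failures = [msg for msg in messages if not msg.is_informational]

# Unrecoverable failures are written to a separate processing_errors.jsonl via
# CriticalProcessingError, mirroring the exception handlers in the pipeline steps.
try:
    raise RuntimeError("example failure")
except RuntimeError as exc:
    dump_processing_errors(
        working_dir, "data_contract", [CriticalProcessingError.from_exception(exc)]
    )
```

Because the loader is a generator over the JSON-lines file, callers that need both the warning and failure views (as the sketch does) should collect it into a list first; a single pass, as in `apply_data_contract`, can consume it lazily.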