diff --git a/map2loop/map2model_wrapper.py b/map2loop/map2model_wrapper.py deleted file mode 100644 index 115b8702..00000000 --- a/map2loop/map2model_wrapper.py +++ /dev/null @@ -1,190 +0,0 @@ -# internal imports -from .m2l_enums import VerboseLevel - -# external imports -import geopandas as gpd -import pandas as pd -import numpy as np - -from .logging import getLogger - -logger = getLogger(__name__) - - -class Map2ModelWrapper: - """ - A wrapper around map2model functionality - - Attributes - ---------- - sorted_units: None or list - map2model's estimate of the stratigraphic column - fault_fault_relationships: None or pandas.DataFrame - data frame of fault to fault relationships with columns ["Fault1", "Fault2", "Type", "Angle"] - unit_fault_relationships: None or pandas.DataFrame - data frame of unit fault relationships with columns ["Unit", "Fault"] - unit_unit_relationships: None or pandas.DataFrame - data frame of unit unit relationships with columns ["Index1", "UnitName1", "Index2", "UnitName2"] - map_data: MapData - A pointer to the map data structure in project - verbose_level: m2l_enum.VerboseLevel - A selection that defines how much console logging is output - """ - - def __init__( - self, map_data, *, verbose_level: VerboseLevel = VerboseLevel.NONE - ): - """ - The initialiser for the map2model wrapper - - Args: - map_data (MapData): - The project map data structure to reference - verbose_level (VerboseLevel, optional): - How much console output is sent. Defaults to VerboseLevel.ALL. - """ - self.sorted_units = None - self._fault_fault_relationships = None - self._unit_fault_relationships = None - self._unit_unit_relationships = None - self.map_data = map_data - self.verbose_level = verbose_level - self.buffer_radius = 500 - - @property - def fault_fault_relationships(self): - if self._fault_fault_relationships is None: - self._calculate_fault_fault_relationships() - - return self._fault_fault_relationships - - @property - def unit_fault_relationships(self): - if self._unit_fault_relationships is None: - self._calculate_fault_unit_relationships() - - return self._unit_fault_relationships - - @property - def unit_unit_relationships(self): - if self._unit_unit_relationships is None: - self._calculate_unit_unit_relationships() - - return self._unit_unit_relationships - - def reset(self): - """ - Reset the wrapper to before the map2model process - """ - logger.info("Resetting map2model wrapper") - self.sorted_units = None - self.fault_fault_relationships = None - self.unit_fault_relationships = None - self.unit_unit_relationships = None - - def get_sorted_units(self): - """ - Getter for the map2model sorted units - - Returns: - list: The map2model stratigraphic column estimate - """ - raise NotImplementedError("This method is not implemented") - - - def get_fault_fault_relationships(self): - """ - Getter for the fault fault relationships - - Returns: - pandas.DataFrame: The fault fault relationships - """ - - return self.fault_fault_relationships - - def get_unit_fault_relationships(self): - """ - Getter for the unit fault relationships - - Returns: - pandas.DataFrame: The unit fault relationships - """ - - return self.unit_fault_relationships - - def get_unit_unit_relationships(self): - """ - Getter for the unit unit relationships - - Returns: - pandas.DataFrame: The unit unit relationships - """ - - return self.unit_unit_relationships - - def _calculate_fault_fault_relationships(self): - - faults = self.map_data.FAULT.copy() - # reset index so that we can index the adjacency matrix with the index - faults.reset_index(inplace=True) - buffers = faults.buffer(self.buffer_radius) - # create the adjacency matrix - intersection = gpd.sjoin( - gpd.GeoDataFrame(geometry=buffers), gpd.GeoDataFrame(geometry=faults["geometry"]) - ) - intersection["index_left"] = intersection.index - intersection.reset_index(inplace=True) - - adjacency_matrix = np.zeros((faults.shape[0], faults.shape[0]), dtype=bool) - adjacency_matrix[ - intersection.loc[:, "index_left"], intersection.loc[:, "index_right"] - ] = True - f1, f2 = np.where(np.tril(adjacency_matrix, k=-1)) - df = pd.DataFrame( - {'Fault1': faults.loc[f1, 'ID'].to_list(), 'Fault2': faults.loc[f2, 'ID'].to_list()} - ) - df['Angle'] = 60 # make it big to prevent LS from making splays - df['Type'] = 'T' - self._fault_fault_relationships = df - - def _calculate_fault_unit_relationships(self): - """Calculate unit/fault relationships using geopandas sjoin. - This will return - """ - units = self.map_data.GEOLOGY["UNITNAME"].unique() - faults = self.map_data.FAULT.copy().reset_index().drop(columns=['index']) - adjacency_matrix = np.zeros((len(units), faults.shape[0]), dtype=bool) - for i, u in enumerate(units): - unit = self.map_data.GEOLOGY[self.map_data.GEOLOGY["UNITNAME"] == u] - intersection = gpd.sjoin( - gpd.GeoDataFrame(geometry=faults["geometry"]), - gpd.GeoDataFrame(geometry=unit["geometry"]), - ) - intersection["index_left"] = intersection.index - intersection.reset_index(inplace=True) - adjacency_matrix[i, intersection.loc[:, "index_left"]] = True - u, f = np.where(adjacency_matrix) - df = pd.DataFrame({"Unit": units[u].tolist(), "Fault": faults.loc[f, "ID"].to_list()}) - self._unit_fault_relationships = df - - def _calculate_unit_unit_relationships(self): - if self.map_data.contacts is None: - self.map_data.extract_all_contacts() - self._unit_unit_relationships = self.map_data.contacts.copy().drop( - columns=['length', 'geometry'] - ) - return self._unit_unit_relationships - - def run(self, verbose_level: VerboseLevel = None): - """ - The main execute function that prepares, runs and parse the output of the map2model process - - Args: - verbose_level (VerboseLevel, optional): - How much console output is sent. Defaults to None (which uses the wrapper attribute). - """ - - self.get_fault_fault_relationships() - self.get_unit_fault_relationships() - self.get_unit_unit_relationships() - diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index 4137af27..d8ae98c1 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -1360,11 +1360,10 @@ def calculate_bounding_box_and_projection(self): def export_wkt_format_files(self): """ Save out the geology and fault GeoDataFrames in WKT format - This is used by map2model + This is used by topology """ - # TODO: - Move away from tab seperators entirely (topology and map2model) - - self.map2model_tmp_path = pathlib.Path(tempfile.mkdtemp()) + # TODO: - Move away from tab seperators entirely (topology) + self.topology_tmp_path = pathlib.Path(tempfile.mkdtemp()) # Check geology data status and export to a WKT format file self.load_map_data(Datatype.GEOLOGY) @@ -1398,7 +1397,7 @@ def export_wkt_format_files(self): geology["ROCKTYPE1"] = geology["ROCKTYPE1"].replace("", "None") geology["ROCKTYPE2"] = geology["ROCKTYPE2"].replace("", "None") geology.to_csv( - pathlib.Path(self.map2model_tmp_path) / "geology_wkt.csv", sep="\t", index=False + pathlib.Path(self.topology_tmp_path) / "geology_wkt.csv", sep="\t", index=False ) # Check faults data status and export to a WKT format file @@ -1413,7 +1412,7 @@ def export_wkt_format_files(self): faults = self.get_map_data(Datatype.FAULT).copy() faults.rename(columns={"geometry": "WKT"}, inplace=True) faults.to_csv( - pathlib.Path(self.map2model_tmp_path) / "faults_wkt.csv", sep="\t", index=False + pathlib.Path(self.topology_tmp_path) / "faults_wkt.csv", sep="\t", index=False ) @beartype.beartype @@ -1448,63 +1447,6 @@ def get_value_from_raster(self, datatype: Datatype, x, y): val = data.ReadAsArray(px, py, 1, 1)[0][0] return val - @beartype.beartype - def __value_from_raster(self, inv_geotransform, data, x: float, y: float): - """ - Get the value from a raster dataset at the specified point - - Args: - inv_geotransform (gdal.GeoTransform): - The inverse of the data's geotransform - data (numpy.array): - The raster data - x (float): - The easting coordinate of the value - y (float): - The northing coordinate of the value - - Returns: - float or int: The value at the point specified - """ - px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) - py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) - # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP - px = max(px, 0) - px = min(px, data.shape[0] - 1) - py = max(py, 0) - py = min(py, data.shape[1] - 1) - return data[px][py] - - @beartype.beartype - def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame): - """ - Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates - - Args: - datatype (Datatype): - The datatype of the raster map to retrieve from - df (pandas.DataFrame): - The original dataframe with 'X' and 'Y' columns - - Returns: - pandas.DataFrame: The modified dataframe - """ - if len(df) <= 0: - df["Z"] = [] - return df - data = self.get_map_data(datatype) - if data is None: - logger.warning("Cannot get value from data as data is not loaded") - return None - - inv_geotransform = gdal.InvGeoTransform(data.GetGeoTransform()) - data_array = numpy.array(data.GetRasterBand(1).ReadAsArray().T) - - df["Z"] = df.apply( - lambda row: self.__value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), - axis=1, - ) - return df @beartype.beartype def extract_all_contacts(self, save_contacts=True): diff --git a/map2loop/project.py b/map2loop/project.py index d9cfbb83..0c01eeb5 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -1,16 +1,16 @@ # internal imports from map2loop.fault_orientation import FaultOrientationNearest -from .utils import hex_to_rgb +from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData -from .sampler import Sampler, SamplerDecimator, SamplerSpacing +from .sampler import sample_data from .thickness_calculator import InterpolatedStructure, ThicknessCalculator from .throw_calculator import ThrowCalculator, ThrowCalculatorAlpha from .fault_orientation import FaultOrientation from .sorter import Sorter, SorterAgeBased, SorterAlpha, SorterUseNetworkX, SorterUseHint from .stratigraphic_column import StratigraphicColumn from .deformation_history import DeformationHistory -from .map2model_wrapper import Map2ModelWrapper +from .topology import run_topology, calculate_unit_unit_relationships from .data_checks import validate_config_dictionary # external imports @@ -39,16 +39,16 @@ class Project(object): ----------- verbose_level: m2l_enums.VerboseLevel A selection that defines how much console logging is output - samplers: Sampler - A list of samplers used to extract point samples from polyonal or line segments. Indexed by m2l_enum.Dataype + samplers: dict + A dictionary to store sampler names and parameters sorter: Sorter The sorting algorithm to use for calculating the stratigraphic column loop_filename: str The name of the loop project file used in this project map_data: MapData The structure that holds all map and dtm data - map2model: Map2ModelWrapper - A wrapper around the map2model module that extracts unit and fault adjacency + topology_results: dict + A dictionary storing results from the topology process stratigraphic_column: StratigraphicColumn The structure that holds the unit information and ordering deformation_history: DeformationHistory @@ -135,7 +135,7 @@ def __init__( self._error_state = ErrorState.NONE self._error_state_msg = "" self.verbose_level = verbose_level - self.samplers = [SamplerDecimator()] * len(Datatype) + self.samplers = {} # Dictionary to store sampler names and parameters self.set_default_samplers() self.bounding_box = bounding_box self.sorter = SorterUseHint() @@ -143,7 +143,7 @@ def __init__( self.throw_calculator = ThrowCalculatorAlpha() self.fault_orientation = FaultOrientationNearest() self.map_data = MapData(verbose_level=verbose_level) - self.map2model = Map2ModelWrapper(self.map_data) + self.topology_results = None self.stratigraphic_column = StratigraphicColumn() self.deformation_history = DeformationHistory(project=self) self.loop_filename = loop_project_filename @@ -397,41 +397,44 @@ def set_default_samplers(self): Initialisation function to set or reset the point samplers """ logger.info("Setting default samplers") - self.samplers[Datatype.STRUCTURE] = SamplerDecimator(1) - self.samplers[Datatype.GEOLOGY] = SamplerSpacing(50.0) - self.samplers[Datatype.FAULT] = SamplerSpacing(50.0) - self.samplers[Datatype.FOLD] = SamplerSpacing(50.0) - self.samplers[Datatype.DTM] = SamplerSpacing(50.0) + self.samplers = { + Datatype.STRUCTURE: {"name": "decimator", "params": {"decimation": 1}}, + Datatype.GEOLOGY: {"name": "spacing", "params": {"spacing": 50.0}}, + Datatype.FAULT: {"name": "spacing", "params": {"spacing": 50.0}}, + Datatype.FOLD: {"name": "spacing", "params": {"spacing": 50.0}}, + Datatype.DTM: {"name": "spacing", "params": {"spacing": 50.0}} + } @beartype.beartype - def set_sampler(self, datatype: Datatype, sampler: Sampler): + def set_sampler(self, datatype: Datatype, sampler_name: str, **kwargs): """ Set the point sampler for a specific datatype Args: datatype (Datatype): The datatype to use this sampler on - sampler (Sampler): - The sampler to use + sampler_name (str): + The name of the sampler to use + **kwargs: + Additional parameters for the sampler """ allowed_samplers = { - Datatype.STRUCTURE: SamplerDecimator, - Datatype.GEOLOGY: SamplerSpacing, - Datatype.FAULT: SamplerSpacing, - Datatype.FOLD: SamplerSpacing, - Datatype.DTM: SamplerSpacing, + Datatype.STRUCTURE: ["decimator"], + Datatype.GEOLOGY: ["spacing"], + Datatype.FAULT: ["spacing"], + Datatype.FOLD: ["spacing"], + Datatype.DTM: ["spacing"], } # Check for wrong sampler if datatype in allowed_samplers: - allowed_sampler_type = allowed_samplers[datatype] - if not isinstance(sampler, allowed_sampler_type): + if sampler_name not in allowed_samplers[datatype]: raise ValueError( - f"Got wrong argument for this datatype: {type(sampler).__name__}, please use {allowed_sampler_type.__name__} instead" + f"Invalid sampler {sampler_name} for datatype {datatype}, please use {allowed_samplers[datatype]} instead" ) - ## does the enum print the number or the label? - logger.info(f"Setting sampler for {datatype} to {sampler.sampler_label}") - self.samplers[datatype] = sampler + + logger.info(f"Setting sampler for {datatype} to {sampler_name}") + self.samplers[datatype] = {"name": sampler_name, "params": kwargs} @beartype.beartype def get_sampler(self, datatype: Datatype): @@ -444,7 +447,7 @@ def get_sampler(self, datatype: Datatype): Returns: str: The name of the sampler being used on the specified datatype """ - return self.samplers[datatype].sampler_label + return self.samplers[datatype]["name"] @beartype.beartype def set_minimum_fault_length(self, length: Union[float, int]): @@ -503,26 +506,53 @@ def sample_map_data(self): """ Use the samplers to extract points along polylines or unit boundaries """ - logger.info( - f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}" - ) - self.geology_samples = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data - ) - logger.info( - f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}" - ) - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample( - self.map_data.get_map_data(Datatype.STRUCTURE), self.map_data - ) - logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") - self.fault_samples = self.samplers[Datatype.FAULT].sample( - self.map_data.get_map_data(Datatype.FAULT), self.map_data - ) - logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD].sampler_label}") - self.fold_samples = self.samplers[Datatype.FOLD].sample( - self.map_data.get_map_data(Datatype.FOLD), self.map_data - ) + + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + + try: + logger.info(f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY]['name']}") + self.geology_samples = sample_data( + geology_data,self.samplers[Datatype.GEOLOGY]["name"], + **self.samplers[Datatype.GEOLOGY]["params"] + ) + except Exception as e: + logger.error(f"Error sampling geology map data: {str(e)}") + raise + try: + logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE]['name']}") + self.structure_samples = sample_data( + self.map_data.get_map_data(Datatype.STRUCTURE), + self.samplers[Datatype.STRUCTURE]["name"], + dtm_data=dtm_data, + geology_data=geology_data, + **self.samplers[Datatype.STRUCTURE]["params"] + ) + except Exception as e: + logger.error(f"Error sampling structure map data: {str(e)}") + raise + + try: + logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT]['name']}") + self.fault_samples = sample_data( + self.map_data.get_map_data(Datatype.FAULT), + self.samplers[Datatype.FAULT]["name"], + **self.samplers[Datatype.FAULT]["params"] + ) + except Exception as e: + logger.error(f"Error sampling fault map data: {str(e)}") + raise + + try: + logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD]['name']}") + self.fold_samples = sample_data( + self.map_data.get_map_data(Datatype.FOLD), + self.samplers[Datatype.FOLD]["name"], + **self.samplers[Datatype.FOLD]["params"] + ) + except Exception as e: + logger.error(f"Error sampling fold map data: {str(e)}") + raise def extract_geology_contacts(self): """ @@ -532,11 +562,13 @@ def extract_geology_contacts(self): self.map_data.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts - self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.basal_contacts + self.map_data.sampled_contacts = sample_data( + self.map_data.basal_contacts, + self.samplers[Datatype.GEOLOGY]["name"], + **self.samplers[Datatype.GEOLOGY]["params"] ) - - self.map_data.get_value_from_raster_df(Datatype.DTM, self.map_data.sampled_contacts) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts) def calculate_stratigraphic_order(self, take_best=False): """ @@ -551,7 +583,7 @@ def calculate_stratigraphic_order(self, take_best=False): columns = [ sorter.sort( self.stratigraphic_column.stratigraphicUnits, - self.map2model.get_unit_unit_relationships(), + calculate_unit_unit_relationships(self.map_data), self.map_data.contacts, self.map_data, ) @@ -581,7 +613,7 @@ def calculate_stratigraphic_order(self, take_best=False): logger.info(f'Calculating stratigraphic column using sorter {self.sorter.sorter_label}') self.stratigraphic_column.column = self.sorter.sort( self.stratigraphic_column.stratigraphicUnits, - self.map2model.get_unit_unit_relationships(), + calculate_unit_unit_relationships(self.map_data), self.map_data.contacts, self.map_data, ) @@ -714,7 +746,8 @@ def calculate_fault_orientations(self): self.map_data.get_map_data(Datatype.FAULT_ORIENTATION), self.map_data, ) - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_orientations) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_orientations) else: logger.warning( "No fault orientation data found, skipping fault orientation calculation" @@ -739,7 +772,8 @@ def summarise_fault_data(self): """ Use the fault shapefile to make a summary of each fault by name """ - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_samples) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_samples) self.deformation_history.summarise_data(self.fault_samples) self.deformation_history.faults = self.throw_calculator.compute( @@ -768,7 +802,10 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): # Calculate the stratigraphic column if issubclass(type(user_defined_stratigraphic_column), list): self.stratigraphic_column.column = user_defined_stratigraphic_column - self.map2model.run() # if we use a user defined stratigraphic column, we still need to calculate the results of map2model + self.topology_results = run_topology( + self.map_data.FAULT, + self.map_data.GEOLOGY, + ) else: if user_defined_stratigraphic_column is not None: logger.warning( @@ -779,7 +816,11 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): # Calculate basal contacts based on stratigraphic column self.extract_geology_contacts() - self.sample_map_data() + try: + self.sample_map_data() + except Exception as e: + logger.error(f"Error during map data sampling in run_all: {str(e)}") + raise self.calculate_unit_thicknesses() self.calculate_fault_orientations() self.summarise_fault_data() @@ -1012,9 +1053,9 @@ def save_into_projectfile(self): observations["dipPolarity"] = self.structure_samples["OVERTURNED"] LPF.Set(self.loop_filename, "stratigraphicObservations", data=observations) - if self.map2model.fault_fault_relationships is not None: + if self.topology_results and self.topology_results.get("fault_fault_relationships") is not None: ff_relationships = self.deformation_history.get_fault_relationships_with_ids( - self.map2model.fault_fault_relationships + self.topology_results["fault_fault_relationships"] ) relationships = numpy.zeros(len(ff_relationships), LPF.eventRelationshipType) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 01600566..0b7cedc6 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,173 +1,165 @@ # internal imports -from .m2l_enums import Datatype -from .mapdata import MapData +from .utils import set_z_values_from_raster_df # external imports -from abc import ABC, abstractmethod import beartype import geopandas import pandas import shapely import numpy from typing import Optional +import osgeo -class Sampler(ABC): +_SAMPLER_REGISTRY = {} + +@beartype.beartype +def register_sampler(name: str): """ - Base Class of Sampler used to force structure of Sampler + Register a sampler function with a given name. Args: - ABC (ABC): Derived from Abstract Base Class + name (str): the name of the sampler """ + def decorator(func): + _SAMPLER_REGISTRY[name] = func + return func + return decorator - def __init__(self): - """ - Initialiser of for Sampler - """ - self.sampler_label = "SamplerBaseClass" - - def type(self): - """ - Getter for subclass type label - - Returns: - str: Name of subclass - """ - return self.sampler_label - - @beartype.beartype - @abstractmethod - def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None - ) -> pandas.DataFrame: - """ - Execute sampling method (abstract method) - - Args: - spatial_data (geopandas.GeoDataFrame): data frame to sample +@beartype.beartype +def get_sampler(name: str): + """ + Get a sampler function by name. - Returns: - pandas.DataFrame: data frame containing samples - """ - pass + Args: + name (str): the name of the sampler to retrieve + Returns: + function: the sampler function + """ + if name not in _SAMPLER_REGISTRY: + raise ValueError(f"Sampler {name} not found") + return _SAMPLER_REGISTRY[name] + +@beartype.beartype +def sample_data( + spatial_data: geopandas.GeoDataFrame, + sampler_name: str, + dtm_data: Optional[osgeo.gdal.Dataset] = None, + geology_data: Optional[geopandas.GeoDataFrame] = None, + **kwargs +)-> pandas.DataFrame: + """ + Execute sampling method based on the provided sampler name -class SamplerDecimator(Sampler): + Args: + spatial_data (geopandas.GeoDataFrame): data frame to sample + sampler_name (str): the name of the sampler to use + dtm_data (Optional[osgeo.gdal.Dataset]): dtm data required for decimator sampler + geology_data (Optional[geopandas.GeoDataFrame]): geology data required for decimator sampler + **kwargs: additional arguments to pass to the sampler + + Returns: + pandas.DataFrame: data frame containing samples """ - Decimator sampler class which decimates the geo data frame based on the decimation value - ie. decimation = 10 means take every tenth point - Note: This only works on data frames with lists of points with columns "X" and "Y" + sampler = get_sampler(sampler_name) + if sampler_name == 'decimator': + if dtm_data is None or geology_data is None: + raise ValueError("sample decimator requires both dtm and geology data") + return sampler(spatial_data=spatial_data, dtm_data=dtm_data, geology_data=geology_data, **kwargs) + else: + return sampler(spatial_data=spatial_data, **kwargs) + +@register_sampler("decimator") +@beartype.beartype +def sample_decimator( + spatial_data: geopandas.GeoDataFrame, + dtm_data: osgeo.gdal.Dataset, + geology_data: geopandas.GeoDataFrame, + decimation: int = 1 +) -> pandas.DataFrame: """ + Execute sample method takes full point data, samples the data and returns the decimated points + + Args: + spatial_data (geopandas.GeoDataFrame): data frame to sample + dtm_data (osgeo.gdal.Dataset): dtm data + geology_data (geopandas.GeoDataFrame): geology data + decimation (int, optional): the decimation factor. Default to 1 - @beartype.beartype - def __init__(self, decimation: int = 1): - """ - Initialiser for decimator sampler - - Args: - decimation (int, optional): stride of the points to sample. Defaults to 1. - """ - self.sampler_label = "SamplerDecimator" - decimation = max(decimation, 1) - self.decimation = decimation - - @beartype.beartype - def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None - ) -> pandas.DataFrame: - """ - Execute sample method takes full point data, samples the data and returns the decimated points - - Args: - spatial_data (geopandas.GeoDataFrame): the data frame to sample - - Returns: - pandas.DataFrame: the sampled data points - """ - data = spatial_data.copy() - data["X"] = data.geometry.x - data["Y"] = data.geometry.y - data["Z"] = map_data.get_value_from_raster_df(Datatype.DTM, data)["Z"] - data["layerID"] = geopandas.sjoin( - data, map_data.get_map_data(Datatype.GEOLOGY), how='left' - )['index_right'] - data.reset_index(drop=True, inplace=True) - - return pandas.DataFrame(data[:: self.decimation].drop(columns="geometry")) - - -class SamplerSpacing(Sampler): + Returns: + pandas.DataFrame: the sampled data points """ - Spacing based sampler which decimates the geo data frame based on the distance between points along a line or - in the case of a polygon along the boundary of that polygon - ie. spacing = 500 means take a sample every 500 metres - Note: This only works on data frames that contain MultiPolgon, Polygon, MultiLineString and LineString geometry + decimation = max(decimation, 1) + data = spatial_data.copy() + data["X"] = data.geometry.x + data["Y"] = data.geometry.y + data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + + data["layerID"] = geopandas.sjoin( + data, geology_data, how='left' + )['index_right'] + + data.reset_index(drop=True, inplace=True) + + return pandas.DataFrame(data[:: decimation].drop(columns="geometry")) + +@register_sampler("spacing") +@beartype.beartype +def sample_spacing( + spatial_data: geopandas.GeoDataFrame, + spacing: float = 50.0, +) -> pandas.DataFrame: """ + Execute sample method takes full point data, samples the data and returns the sampled points - @beartype.beartype - def __init__(self, spacing: float = 50.0): - """ - Initialiser for spacing sampler - - Args: - spacing (float, optional): The distance between samples. Defaults to 50.0. - """ - self.sampler_label = "SamplerSpacing" - spacing = max(spacing, 1.0) - self.spacing = spacing - - @beartype.beartype - def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None - ) -> pandas.DataFrame: - """ - Execute sample method takes full point data, samples the data and returns the sampled points - - Args: - spatial_data (geopandas.GeoDataFrame): the data frame to sample (must contain column ["ID"]) - - Returns: - pandas.DataFrame: the sampled data points - """ - schema = {"ID": str, "X": float, "Y": float, "featureId": str} - df = pandas.DataFrame(columns=schema.keys()).astype(schema) - for _, row in spatial_data.iterrows(): - if type(row.geometry) is shapely.geometry.multipolygon.MultiPolygon: - targets = row.geometry.boundary.geoms - elif type(row.geometry) is shapely.geometry.polygon.Polygon: - targets = [row.geometry.boundary] - elif type(row.geometry) is shapely.geometry.multilinestring.MultiLineString: - targets = row.geometry.geoms - elif type(row.geometry) is shapely.geometry.linestring.LineString: - targets = [row.geometry] - else: - targets = [] - - # For the main cases Polygon and LineString the list 'targets' has one element - for a, target in enumerate(targets): - df2 = pandas.DataFrame(columns=schema.keys()).astype(schema) - distances = numpy.arange(0, target.length, self.spacing)[:-1] - points = [target.interpolate(distance) for distance in distances] - df2["X"] = [point.x for point in points] - df2["Y"] = [point.y for point in points] - - # # account for holes//rings in polygons - df2["featureId"] = str(a) - # 1. check if line is "closed" - if target.is_ring: - target_polygon = shapely.geometry.Polygon(target) - if target_polygon.exterior.is_ccw: # if counterclockwise --> hole - for j, target2 in enumerate(targets): - # skip if line or point - if len(target2.coords) >= 2: - continue - # which poly is the hole in? assign featureId of the same poly - t2_polygon = shapely.geometry.Polygon(target2) - if target.within(t2_polygon): # - df2['featureId'] = str(j) - - df2["ID"] = row["ID"] if "ID" in spatial_data.columns else 0 - df = df2 if len(df) == 0 else pandas.concat([df, df2]) - - df.reset_index(drop=True, inplace=True) - return df + Args: + spatial_data (geopandas.GeoDataFrame): the data frame to sample (must contain column ["ID"]) + spacing (float, optional): the spacing between points. Default to 50.0 + + Returns: + pandas.DataFrame: the sampled data points + """ + spacing = max(spacing, 1.0) + schema = {"ID": str, "X": float, "Y": float, "featureId": str} + df = pandas.DataFrame(columns=schema.keys()).astype(schema) + for _, row in spatial_data.iterrows(): + if type(row.geometry) is shapely.geometry.multipolygon.MultiPolygon: + targets = row.geometry.boundary.geoms + elif type(row.geometry) is shapely.geometry.polygon.Polygon: + targets = [row.geometry.boundary] + elif type(row.geometry) is shapely.geometry.multilinestring.MultiLineString: + targets = row.geometry.geoms + elif type(row.geometry) is shapely.geometry.linestring.LineString: + targets = [row.geometry] + else: + targets = [] + + # For the main cases Polygon and LineString the list 'targets' has one element + for a, target in enumerate(targets): + df2 = pandas.DataFrame(columns=schema.keys()).astype(schema) + distances = numpy.arange(0, target.length, spacing)[:-1] + points = [target.interpolate(distance) for distance in distances] + df2["X"] = [point.x for point in points] + df2["Y"] = [point.y for point in points] + + # # account for holes//rings in polygons + df2["featureId"] = str(a) + # 1. check if line is "closed" + if target.is_ring: + target_polygon = shapely.geometry.Polygon(target) + if target_polygon.exterior.is_ccw: # if counterclockwise --> hole + for j, target2 in enumerate(targets): + # skip if line or point + if len(target2.coords) >= 2: + continue + # which poly is the hole in? assign featureId of the same poly + t2_polygon = shapely.geometry.Polygon(target2) + if target.within(t2_polygon): # + df2['featureId'] = str(j) + + df2["ID"] = row["ID"] if "ID" in spatial_data.columns else 0 + df = df2 if len(df) == 0 else pandas.concat([df, df2]) + + df.reset_index(drop=True, inplace=True) + return df diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index d7a9aad1..3da0ad40 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -5,6 +5,7 @@ calculate_endpoints, multiline_to_line, find_segment_strike_from_pt, + set_z_values_from_raster_df ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator @@ -271,7 +272,8 @@ def compute( # set the crs of the contacts to the crs of the units contacts = contacts.set_crs(crs=basal_contacts.crs) # get the elevation Z of the contacts - contacts = map_data.get_value_from_raster_df(Datatype.DTM, contacts) + dtm_data = map_data.get_map_data(Datatype.DTM) + contacts = set_z_values_from_raster_df(dtm_data, contacts) # update the geometry of the contact points to include the Z value contacts["geometry"] = contacts.apply( lambda row: shapely.geometry.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 @@ -299,7 +301,8 @@ def compute( # set the crs of the interpolated orientations to the crs of the units interpolated_orientations = interpolated_orientations.set_crs(crs=basal_contacts.crs) # get the elevation Z of the interpolated points - interpolated = map_data.get_value_from_raster_df(Datatype.DTM, interpolated_orientations) + dtm_data = map_data.get_map_data(Datatype.DTM) + interpolated = set_z_values_from_raster_df(dtm_data, interpolated_orientations) # update the geometry of the interpolated points to include the Z value interpolated["geometry"] = interpolated.apply( lambda row: shapely.geometry.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 diff --git a/map2loop/topology.py b/map2loop/topology.py new file mode 100644 index 00000000..724f59de --- /dev/null +++ b/map2loop/topology.py @@ -0,0 +1,136 @@ +# internal imports +from .logging import getLogger +from map2loop.map_data import MapData + +# external imports +import inspect +import geopandas +import pandas +import numpy +import beartype +from beartype.typing import Dict + +logger = getLogger(__name__) + +_TOPOLOGY_REGISTRY = {} + +@beartype.beartype +def register_topology(name: str): + """Register a topology function with a given name.""" + def decorator(func): + _TOPOLOGY_REGISTRY[name] = func + return func + return decorator + +@beartype.beartype +def get_topology(name: str): + """Retrieve a registered topology function.""" + if name not in _TOPOLOGY_REGISTRY: + raise ValueError(f"Topology {name} not found") + return _TOPOLOGY_REGISTRY[name] + +@register_topology("fault_fault_relationships") +@beartype.beartype +def calculate_fault_fault_relationships( + fault_layer: geopandas.GeoDataFrame, + buffer_radius: float = 500, +) -> pandas.DataFrame: + """Calculate fault to fault relationships.""" + faults = fault_layer.copy() + faults.reset_index(inplace=True) + buffers = faults.buffer(buffer_radius) + intersection = geopandas.sjoin( + geopandas.GeoDataFrame(geometry=buffers), + geopandas.GeoDataFrame(geometry=faults["geometry"]), + ) + intersection["index_left"] = intersection.index + intersection.reset_index(inplace=True) + + adjacency_matrix = numpy.zeros((faults.shape[0], faults.shape[0]), dtype=bool) + adjacency_matrix[ + intersection.loc[:, "index_left"], + intersection.loc[:, "index_right"], + ] = True + f1, f2 = numpy.where(numpy.tril(adjacency_matrix, k=-1)) + df = pandas.DataFrame( + { + "Fault1": faults.loc[f1, "ID"].to_list(), + "Fault2": faults.loc[f2, "ID"].to_list(), + } + ) + df["Angle"] = 60 + df["Type"] = "T" + return df + +@register_topology("unit_fault_relationships") +@beartype.beartype +def calculate_unit_fault_relationships( + fault_layer: geopandas.GeoDataFrame, + geology_layer: geopandas.GeoDataFrame, + buffer_radius: float = 500, +) -> pandas.DataFrame: + """Calculate unit to fault relationships.""" + units = geology_layer["UNITNAME"].unique() + faults = fault_layer.copy().reset_index().drop(columns=["index"]) + adjacency_matrix = numpy.zeros((len(units), faults.shape[0]), dtype=bool) + for i, u in enumerate(units): + unit = geology_layer[geology_layer["UNITNAME"] == u] + intersection = geopandas.sjoin( + geopandas.GeoDataFrame(geometry=faults["geometry"]), + geopandas.GeoDataFrame(geometry=unit["geometry"]), + ) + intersection["index_left"] = intersection.index + intersection.reset_index(inplace=True) + adjacency_matrix[i, intersection.loc[:, "index_left"]] = True + u_idx, f_idx = numpy.where(adjacency_matrix) + df = pandas.DataFrame({"Unit": units[u_idx].tolist(), "Fault": faults.loc[f_idx, "ID"].to_list()}) + return df +@register_topology("unit_unit_relationships") +@beartype.beartype +def calculate_unit_unit_relationships( + geology_layer: geopandas.GeoDataFrame, + contacts: pandas.DataFrame = None, + map_data: MapData = None, +) -> pandas.DataFrame: + """Calculate unit to unit relationships.""" + if contacts is None: + map_data.extract_all_contacts() + return map_data.contacts.copy().drop(columns=["length", "geometry"]) + +@beartype.beartype +def run_topology( + fault_layer: geopandas.GeoDataFrame, + geology_layer: geopandas.GeoDataFrame, + topology_name: str = "default", + **kwargs, +) -> Dict[str, pandas.DataFrame]: + """Execute a topology function by name.""" + runner = get_topology(topology_name) + signature = inspect.signature(runner) + call_args = {} + if "fault_layer" in signature.parameters: + call_args["fault_layer"] = fault_layer + if "geology_layer" in signature.parameters: + call_args["geology_layer"] = geology_layer + call_args.update(kwargs) + return runner(**call_args) + +@register_topology("default") +@beartype.beartype +def topology_default( + fault_layer: geopandas.GeoDataFrame, + geology_layer: geopandas.GeoDataFrame, + contacts: pandas.DataFrame = None, + map_data: MapData = None, + buffer_radius: float = 500, +) -> Dict[str, pandas.DataFrame]: + """Calculate topology relationships using basic geopandas operations.""" + ff_df = calculate_fault_fault_relationships(fault_layer, buffer_radius) + uf_df = calculate_unit_fault_relationships(fault_layer, geology_layer, buffer_radius) + uu_df = calculate_unit_unit_relationships(geology_layer, contacts, map_data) + + return { + "fault_fault_relationships": ff_df, + "unit_fault_relationships": uf_df, + "unit_unit_relationships": uu_df, + } diff --git a/map2loop/utils.py b/map2loop/utils.py index c3ed7795..fa48cdeb 100644 --- a/map2loop/utils.py +++ b/map2loop/utils.py @@ -7,6 +7,7 @@ import pandas import re import json +from osgeo import gdal from .logging import getLogger logger = getLogger(__name__) @@ -528,3 +529,62 @@ def update_from_legacy_file( json.dump(parsed_data, f, indent=4) return file_map + +@beartype.beartype +def value_from_raster(inv_geotransform, data, x: float, y: float): + """ + Get the value from a raster dataset at the specified point + + Args: + inv_geotransform (gdal.GeoTransform): + The inverse of the data's geotransform + data (numpy.array): + The raster data + x (float): + The easting coordinate of the value + y (float): + The northing coordinate of the value + + Returns: + float or int: The value at the point specified + """ + px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) + py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) + # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP + px = max(px, 0) + px = min(px, data.shape[0] - 1) + py = max(py, 0) + py = min(py, data.shape[1] - 1) + return data[px][py] + +@beartype.beartype +def set_z_values_from_raster_df(dtm_data: gdal.Dataset, df: pandas.DataFrame): + """ + Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates + + Args: + dtm_data (gdal.Dataset): + Dtm data from raster map + df (pandas.DataFrame): + The original dataframe with 'X' and 'Y' columns + + Returns: + pandas.DataFrame: The modified dataframe + """ + if len(df) <= 0: + df["Z"] = [] + return df + + if dtm_data is None: + logger.warning("Cannot get value from data as data is not loaded") + return None + + inv_geotransform = gdal.InvGeoTransform(dtm_data.GetGeoTransform()) + data_array = numpy.array(dtm_data.GetRasterBand(1).ReadAsArray().T) + + df["Z"] = df.apply( + lambda row: value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), + axis=1, + ) + + return df diff --git a/tests/sampler/test_SamplerSpacing.py b/tests/sampler/test_SamplerSpacing.py index 017f2467..22203850 100644 --- a/tests/sampler/test_SamplerSpacing.py +++ b/tests/sampler/test_SamplerSpacing.py @@ -1,5 +1,5 @@ import pandas -from map2loop.sampler import SamplerSpacing +from map2loop.sampler import sample_data from beartype.roar import BeartypeCallHintParamViolation import pytest import shapely @@ -11,7 +11,7 @@ @pytest.fixture def sampler_spacing(): - return SamplerSpacing(spacing=1.0) + return 1.0 @pytest.fixture @@ -37,7 +37,11 @@ def incorrect_geodata(): # test if correct outputs are generated from the right input def test_sample_function_correct_data(sampler_spacing, correct_geodata): - result = sampler_spacing.sample(correct_geodata) + result = sample_data( + spatial_data = correct_geodata, + sampler_name = 'spacing', + spacing = sampler_spacing + ) assert isinstance(result, pandas.DataFrame) assert 'X' in result.columns assert 'Y' in result.columns @@ -47,25 +51,36 @@ def test_sample_function_correct_data(sampler_spacing, correct_geodata): # add test for incorrect inputs - does it raise a BeartypeCallHintParamViolation error? def test_sample_function_incorrect_data(sampler_spacing, incorrect_geodata): with pytest.raises(BeartypeCallHintParamViolation): - sampler_spacing.sample(spatial_data=incorrect_geodata) + sample_data( + spatial_data = incorrect_geodata, + sampler_name = 'spacing', + spacing = sampler_spacing + ) # for a specific >2 case -def test_sample_function_target_less_than_or_equal_to_2(): - sampler_spacing = SamplerSpacing(spacing=1.0) +def test_sample_function_target_less_than_or_equal_to_2(sampler_spacing): data = { 'geometry': [shapely.geometry.LineString([(0, 0), (0, 1)]), shapely.geometry.LineString([(0, 0), (1, 0)])], 'ID': ['1', '2'], } gdf = geopandas.GeoDataFrame(data, geometry='geometry') - result = sampler_spacing.sample(spatial_data=gdf) + result = sample_data( + spatial_data = gdf, + sampler_name = 'spacing', + spacing = sampler_spacing + ) assert len(result) == 0 # No points should be sampled from the linestring # Test if the extracted points are correct def test_sample_function_extracted_points(sampler_spacing, correct_geodata): - result = sampler_spacing.sample(correct_geodata) + result = sample_data( + spatial_data=correct_geodata, + sampler_name = 'spacing', + spacing = sampler_spacing + ) expected_points = [ (0.0, 0.0), diff --git a/tests/sampler/test_SamplerSpacing_featureId.py b/tests/sampler/test_SamplerSpacing_featureId.py index 73faaa3e..f1b3ffd4 100644 --- a/tests/sampler/test_SamplerSpacing_featureId.py +++ b/tests/sampler/test_SamplerSpacing_featureId.py @@ -1,5 +1,5 @@ import pandas -from map2loop.sampler import SamplerSpacing +from map2loop.sampler import sample_data import shapely import geopandas @@ -9,8 +9,11 @@ sampler_space = 700.0 -sampler = SamplerSpacing(spacing=sampler_space) -geology_samples = sampler.sample(geology_original) +geology_samples = sample_data( + spatial_data = geology_original, + sampler_name = 'spacing', + spacing = sampler_space +) # the actual test: diff --git a/tests/topology/__init__.py b/tests/topology/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/topology/test_topology_functions.py b/tests/topology/test_topology_functions.py new file mode 100644 index 00000000..b0a670c5 --- /dev/null +++ b/tests/topology/test_topology_functions.py @@ -0,0 +1,95 @@ +import sys +import types + +osgeo_stub = types.ModuleType("osgeo") +osgeo_stub.gdal = types.ModuleType("gdal") +osgeo_stub.osr = types.ModuleType("osr") +def _noop(): + pass +osgeo_stub.gdal.UseExceptions = _noop +class _Dataset: + def GetGeoTransform(self): + return (0, 1, 0, 0, 0, 1) +osgeo_stub.gdal.Dataset = _Dataset +def InvGeoTransform(gt): + return gt +osgeo_stub.gdal.InvGeoTransform = InvGeoTransform +sys.modules.setdefault("osgeo", osgeo_stub) +sys.modules.setdefault("osgeo.gdal", osgeo_stub.gdal) +sys.modules.setdefault("osgeo.osr", osgeo_stub.osr) + +import geopandas as gpd +from shapely.geometry import LineString, Polygon +import pandas as pd +from map2loop.topology import ( + calculate_fault_fault_relationships, + calculate_unit_fault_relationships, + calculate_unit_unit_relationships, + register_topology, + run_topology, +) + + +def _create_basic_layers(): + faults = gpd.GeoDataFrame( + { + "geometry": [ + LineString([(0, 0), (1, 0)]), + LineString([(0, 0), (0, 1)]), + ], + "ID": ["F1", "F2"], + }, + geometry="geometry", + crs="EPSG:4326", + ) + + geology = gpd.GeoDataFrame( + { + "UNITNAME": ["U1", "U2"], + "geometry": [ + Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]), + Polygon([(1, 0), (2, 0), (2, 1), (1, 1)]), + ], + }, + geometry="geometry", + crs="EPSG:4326", + ) + + return faults, geology + + +def test_calculate_fault_fault_relationships(): + faults, geology = _create_basic_layers() + df = calculate_fault_fault_relationships(faults, buffer_radius=0.1) + assert len(df) == 1 + assert set(df.iloc[0][["Fault1", "Fault2"]]) == {"F1", "F2"} + + +def test_calculate_unit_fault_relationships(): + faults, geology = _create_basic_layers() + df = calculate_unit_fault_relationships(faults, geology, buffer_radius=0.1) + pairs = {tuple(row) for row in df[["Unit", "Fault"]].to_records(index=False)} + assert pairs == {("U1", "F1"), ("U1", "F2"), ("U2", "F1")} + + +def test_calculate_unit_unit_relationships(): + faults, geology = _create_basic_layers() + df = calculate_unit_unit_relationships(geology) + assert list(df.columns) == ["UNITNAME_1", "UNITNAME_2"] + assert df.iloc[0]["UNITNAME_1"] == "U1" + assert df.iloc[0]["UNITNAME_2"] == "U2" + + +def test_registry_and_runner(): + called = {} + + @register_topology("test") + def run_test(fault_layer: gpd.GeoDataFrame, geology_layer: gpd.GeoDataFrame): + called["executed"] = True + return {"dummy": pd.DataFrame()} + + faults, geology = _create_basic_layers() + result = run_topology(faults, geology, "test") + assert "executed" in called + assert list(result.keys()) == ["dummy"] + assert isinstance(result["dummy"], pd.DataFrame)