From 24acd2922a1a1f0c2ea9252dc1305dbd045c9039 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 30 May 2025 23:18:30 +0800 Subject: [PATCH 01/14] sampler --- map2loop/mapdata.py | 57 ----------------------------- map2loop/project.py | 42 ++++++++++------------ map2loop/sampler.py | 22 ++++++++---- map2loop/thickness_calculator.py | 7 ++-- map2loop/utils.py | 62 +++++++++++++++++++++++++++++++- pyproject.toml | 2 +- 6 files changed, 101 insertions(+), 91 deletions(-) diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index 884ef3d7..25b6d985 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -1433,63 +1433,6 @@ def get_value_from_raster(self, datatype: Datatype, x, y): val = data.ReadAsArray(px, py, 1, 1)[0][0] return val - @beartype.beartype - def __value_from_raster(self, inv_geotransform, data, x: float, y: float): - """ - Get the value from a raster dataset at the specified point - - Args: - inv_geotransform (gdal.GeoTransform): - The inverse of the data's geotransform - data (numpy.array): - The raster data - x (float): - The easting coordinate of the value - y (float): - The northing coordinate of the value - - Returns: - float or int: The value at the point specified - """ - px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) - py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) - # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP - px = max(px, 0) - px = min(px, data.shape[0] - 1) - py = max(py, 0) - py = min(py, data.shape[1] - 1) - return data[px][py] - - @beartype.beartype - def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame): - """ - Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates - - Args: - datatype (Datatype): - The datatype of the raster map to retrieve from - df (pandas.DataFrame): - The original dataframe with 'X' and 'Y' columns - - Returns: - pandas.DataFrame: The modified dataframe - """ - if len(df) <= 0: - df["Z"] = [] - return df - data = self.get_map_data(datatype) - if data is None: - logger.warning("Cannot get value from data as data is not loaded") - return None - - inv_geotransform = gdal.InvGeoTransform(data.GetGeoTransform()) - data_array = numpy.array(data.GetRasterBand(1).ReadAsArray().T) - - df["Z"] = df.apply( - lambda row: self.__value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), - axis=1, - ) - return df @beartype.beartype def extract_all_contacts(self, save_contacts=True): diff --git a/map2loop/project.py b/map2loop/project.py index d9cfbb83..ffa1e1c2 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -1,6 +1,6 @@ # internal imports from map2loop.fault_orientation import FaultOrientationNearest -from .utils import hex_to_rgb +from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData from .sampler import Sampler, SamplerDecimator, SamplerSpacing @@ -506,23 +506,19 @@ def sample_map_data(self): logger.info( f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}" ) - self.geology_samples = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data - ) - logger.info( - f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}" - ) - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample( - self.map_data.get_map_data(Datatype.STRUCTURE), self.map_data - ) + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + + self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data) + logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}") + + self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data) logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") - self.fault_samples = self.samplers[Datatype.FAULT].sample( - self.map_data.get_map_data(Datatype.FAULT), self.map_data - ) + + self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT)) logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD].sampler_label}") - self.fold_samples = self.samplers[Datatype.FOLD].sample( - self.map_data.get_map_data(Datatype.FOLD), self.map_data - ) + + self.fold_samples = self.samplers[Datatype.FOLD].sample(self.map_data.get_map_data(Datatype.FOLD)) def extract_geology_contacts(self): """ @@ -532,11 +528,9 @@ def extract_geology_contacts(self): self.map_data.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts - self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.basal_contacts - ) - - self.map_data.get_value_from_raster_df(Datatype.DTM, self.map_data.sampled_contacts) + self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts) def calculate_stratigraphic_order(self, take_best=False): """ @@ -714,7 +708,8 @@ def calculate_fault_orientations(self): self.map_data.get_map_data(Datatype.FAULT_ORIENTATION), self.map_data, ) - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_orientations) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_orientations) else: logger.warning( "No fault orientation data found, skipping fault orientation calculation" @@ -739,7 +734,8 @@ def summarise_fault_data(self): """ Use the fault shapefile to make a summary of each fault by name """ - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_samples) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_samples) self.deformation_history.summarise_data(self.fault_samples) self.deformation_history.faults = self.throw_calculator.compute( diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 01600566..43db952e 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,6 +1,7 @@ # internal imports from .m2l_enums import Datatype from .mapdata import MapData +from .utils import set_z_values_from_raster_df # external imports from abc import ABC, abstractmethod @@ -10,6 +11,7 @@ import shapely import numpy from typing import Optional +from osgeo import gdal class Sampler(ABC): @@ -38,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -73,7 +75,7 @@ def __init__(self, decimation: int = 1): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the decimated points @@ -87,10 +89,16 @@ def sample( data = spatial_data.copy() data["X"] = data.geometry.x data["Y"] = data.geometry.y - data["Z"] = map_data.get_value_from_raster_df(Datatype.DTM, data)["Z"] - data["layerID"] = geopandas.sjoin( - data, map_data.get_map_data(Datatype.GEOLOGY), how='left' - )['index_right'] + if dtm_data is not None: + data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + else: + data["Z"] = None + if geology_data is not None: + data["layerID"] = geopandas.sjoin( + data, geology_data, how='left' + )['index_right'] + else: + data["layerID"] = None data.reset_index(drop=True, inplace=True) return pandas.DataFrame(data[:: self.decimation].drop(columns="geometry")) @@ -118,7 +126,7 @@ def __init__(self, spacing: float = 50.0): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index eb8a2a67..b869e9d5 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -5,6 +5,7 @@ calculate_endpoints, multiline_to_line, find_segment_strike_from_pt, + set_z_values_from_raster_df ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator @@ -271,7 +272,8 @@ def compute( # set the crs of the contacts to the crs of the units contacts = contacts.set_crs(crs=basal_contacts.crs) # get the elevation Z of the contacts - contacts = map_data.get_value_from_raster_df(Datatype.DTM, contacts) + dtm_data = map_data.get_map_data(Datatype.DTM) + contacts = set_z_values_from_raster_df(dtm_data, contacts) # update the geometry of the contact points to include the Z value contacts["geometry"] = contacts.apply( lambda row: shapely.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 @@ -299,7 +301,8 @@ def compute( # set the crs of the interpolated orientations to the crs of the units interpolated_orientations = interpolated_orientations.set_crs(crs=basal_contacts.crs) # get the elevation Z of the interpolated points - interpolated = map_data.get_value_from_raster_df(Datatype.DTM, interpolated_orientations) + dtm_data = map_data.get_map_data(Datatype.DTM) + interpolated = set_z_values_from_raster_df(dtm_data, interpolated_orientations) # update the geometry of the interpolated points to include the Z value interpolated["geometry"] = interpolated.apply( lambda row: shapely.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 diff --git a/map2loop/utils.py b/map2loop/utils.py index 00e8fa57..4305adca 100644 --- a/map2loop/utils.py +++ b/map2loop/utils.py @@ -7,6 +7,7 @@ import pandas import re import json +from osgeo import gdal from .logging import getLogger logger = getLogger(__name__) @@ -527,4 +528,63 @@ def update_from_legacy_file( with open(json_save_path, "w") as f: json.dump(parsed_data, f, indent=4) - return file_map \ No newline at end of file + return file_map + +@beartype.beartype +def value_from_raster(inv_geotransform, data, x: float, y: float): + """ + Get the value from a raster dataset at the specified point + + Args: + inv_geotransform (gdal.GeoTransform): + The inverse of the data's geotransform + data (numpy.array): + The raster data + x (float): + The easting coordinate of the value + y (float): + The northing coordinate of the value + + Returns: + float or int: The value at the point specified + """ + px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) + py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) + # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP + px = max(px, 0) + px = min(px, data.shape[0] - 1) + py = max(py, 0) + py = min(py, data.shape[1] - 1) + return data[px][py] + +@beartype.beartype +def set_z_values_from_raster_df(dtm_data: gdal.Dataset, df: pandas.DataFrame): + """ + Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates + + Args: + dtm_data (gdal.Dataset): + Dtm data from raster map + df (pandas.DataFrame): + The original dataframe with 'X' and 'Y' columns + + Returns: + pandas.DataFrame: The modified dataframe + """ + if len(df) <= 0: + df["Z"] = [] + return df + + if dtm_data is None: + logger.warning("Cannot get value from data as data is not loaded") + return None + + inv_geotransform = gdal.InvGeoTransform(dtm_data.GetGeoTransform()) + data_array = numpy.array(dtm_data.GetRasterBand(1).ReadAsArray().T) + + df["Z"] = df.apply( + lambda row: value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), + axis=1, + ) + + return df \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4ecb4207..eb1b8498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ name = 'map2loop' description = 'Generate 3D model data from 2D maps.' authors = [{name = 'Loop team'}] readme = 'README.md' -requires-python = '>=3.8' +requires-python = '>=3.8,<3.13' keywords = [ "earth sciences", "geology", "3-D modelling", From 4350df3e818764fe854d3f2d3c702f088785f67a Mon Sep 17 00:00:00 2001 From: noellehmcheng <143368485+noellehmcheng@users.noreply.github.com> Date: Fri, 30 May 2025 15:23:16 +0000 Subject: [PATCH 02/14] style: style fixes by ruff and autoformatting by black --- map2loop/sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 43db952e..f637a72b 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,6 +1,4 @@ # internal imports -from .m2l_enums import Datatype -from .mapdata import MapData from .utils import set_z_values_from_raster_df # external imports From 585cc60a956a7eb70e94d60ba234cb2b748a5cf5 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 3 Jun 2025 11:06:03 +0800 Subject: [PATCH 03/14] fix workflow --- .github/workflows/linting_and_testing.yml | 32 ++++++++++++++++++----- pyproject.toml | 2 +- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/.github/workflows/linting_and_testing.yml b/.github/workflows/linting_and_testing.yml index dc2ed2fe..1f8f32a8 100644 --- a/.github/workflows/linting_and_testing.yml +++ b/.github/workflows/linting_and_testing.yml @@ -9,6 +9,10 @@ jobs: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' - name: Install dependencies run: | python -m pip install --upgrade pip @@ -25,25 +29,39 @@ jobs: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' - name: Install GDAL run: | sudo add-apt-repository ppa:ubuntugis/ubuntugis-unstable sudo apt-get update sudo apt-get install -y libgdal-dev gdal-bin + - name: Set up Miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + python-version: '3.12' + miniforge-version: latest + activate-environment: test-env + use-mamba: true + auto-activate-base: false + - name: Install dependencies + shell: bash -l {0} run: | - conda update -n base -c defaults conda -y - conda install -n base conda-libmamba-solver -c conda-forge -y - conda install -c conda-forge gdal -y - conda install -c conda-forge -c loop3d --file dependencies.txt -y - conda install pytest -y + mamba install python=3.12 -y + mamba install -c conda-forge gdal geopandas shapely networkx owslib beartype pytest scikit-learn -y + pip install map2model loopprojectfile==0.2.2 - name: Install map2loop + shell: bash -l {0} run: | python -m pip install . - name: Run tests + shell: bash -l {0} run: | - pytest - + python -c "import map2model" || echo "map2model not available, tests will use fallback mode" + pytest -v \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index eb1b8498..4ecb4207 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ name = 'map2loop' description = 'Generate 3D model data from 2D maps.' authors = [{name = 'Loop team'}] readme = 'README.md' -requires-python = '>=3.8,<3.13' +requires-python = '>=3.8' keywords = [ "earth sciences", "geology", "3-D modelling", From 029343906a8ad469c291919e73d1e6239589680a Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Mon, 16 Jun 2025 12:50:18 +0800 Subject: [PATCH 04/14] refactor sampler classes to functions --- map2loop/sampler.py | 276 ++++++++++++++++++++------------------------ 1 file changed, 126 insertions(+), 150 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index f637a72b..4e100782 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -12,168 +12,144 @@ from osgeo import gdal -class Sampler(ABC): +_SAMPLER_REGISTRY = {} + +@beartype.beartype +def register_sampler(name: str): """ - Base Class of Sampler used to force structure of Sampler + Register a sampler function with a given name. Args: - ABC (ABC): Derived from Abstract Base Class + name (str): the name of the sampler """ + def decorator(func): + _SAMPLER_REGISTRY[name] = func + return func + return decorator - def __init__(self): - """ - Initialiser of for Sampler - """ - self.sampler_label = "SamplerBaseClass" - - def type(self): - """ - Getter for subclass type label - - Returns: - str: Name of subclass - """ - return self.sampler_label - - @beartype.beartype - @abstractmethod - def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None - ) -> pandas.DataFrame: - """ - Execute sampling method (abstract method) - - Args: - spatial_data (geopandas.GeoDataFrame): data frame to sample - - Returns: - pandas.DataFrame: data frame containing samples - """ - pass - +@beartype.beartype +def get_sampler(name: str): + """ + Get a sampler function by name. -class SamplerDecimator(Sampler): + Args: + name (str): the name of the sampler to retrieve """ - Decimator sampler class which decimates the geo data frame based on the decimation value - ie. decimation = 10 means take every tenth point - Note: This only works on data frames with lists of points with columns "X" and "Y" + if name not in _SAMPLER_REGISTRY: + raise ValueError(f"Sampler {name} not found") + return _SAMPLER_REGISTRY[name] + +@beartype.beartype +def sample_data( + spatial_data: geopandas.GeoDataFrame, + sampler_name: str, + dtm_data: Optional[geopandas.GeoDataFrame] = None, + geology_data: Optional[geopandas.GeoDataFrame] = None, + **kwargs +)-> pandas.DataFrame: """ + Execute sampling method (abstract method) - @beartype.beartype - def __init__(self, decimation: int = 1): - """ - Initialiser for decimator sampler - - Args: - decimation (int, optional): stride of the points to sample. Defaults to 1. - """ - self.sampler_label = "SamplerDecimator" - decimation = max(decimation, 1) - self.decimation = decimation - - @beartype.beartype - def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None - ) -> pandas.DataFrame: - """ - Execute sample method takes full point data, samples the data and returns the decimated points - - Args: - spatial_data (geopandas.GeoDataFrame): the data frame to sample - - Returns: - pandas.DataFrame: the sampled data points - """ - data = spatial_data.copy() - data["X"] = data.geometry.x - data["Y"] = data.geometry.y - if dtm_data is not None: - data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] - else: - data["Z"] = None - if geology_data is not None: - data["layerID"] = geopandas.sjoin( - data, geology_data, how='left' - )['index_right'] - else: - data["layerID"] = None - data.reset_index(drop=True, inplace=True) + Args: + spatial_data (geopandas.GeoDataFrame): data frame to sample - return pandas.DataFrame(data[:: self.decimation].drop(columns="geometry")) + Returns: + pandas.DataFrame: data frame containing samples + """ + sampler = get_sampler(sampler_name) + if sampler_name == 'decimator': + if dtm_data is None or geology_data is None: + raise ValueError("sample decimator requires both dtm and geology data") + return sampler(spatial_data=spatial_data, dtm_data=dtm_data, geology_data=geology_data, **kwargs) + else: + return sampler(spatial_data=spatial_data, **kwargs) + +@register_sampler("decimator") +@beartype.beartype +def sample_decimator( + spatial_data: geopandas.GeoDataFrame, + dtm_data: gdal.Dataset, + geology_data: geopandas.GeoDataFrame, + decimation: int = 1 +) -> pandas.DataFrame: + """ + Execute sample method takes full point data, samples the data and returns the decimated points + Args: + spatial_data (geopandas.GeoDataFrame): the data frame to sample -class SamplerSpacing(Sampler): + Returns: + pandas.DataFrame: the sampled data points """ - Spacing based sampler which decimates the geo data frame based on the distance between points along a line or - in the case of a polygon along the boundary of that polygon - ie. spacing = 500 means take a sample every 500 metres - Note: This only works on data frames that contain MultiPolgon, Polygon, MultiLineString and LineString geometry + decimation = max(decimation, 1) + data = spatial_data.copy() + data["X"] = data.geometry.x + data["Y"] = data.geometry.y + data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + + data["layerID"] = geopandas.sjoin( + data, geology_data, how='left' + )['index_right'] + + data.reset_index(drop=True, inplace=True) + + return pandas.DataFrame(data[:: decimation].drop(columns="geometry")) + +@register_sampler("spacing") +@beartype.beartype +def sample_spacing( + spatial_data: geopandas.GeoDataFrame, + spacing: float = 50.0, +) -> pandas.DataFrame: """ + Execute sample method takes full point data, samples the data and returns the sampled points - @beartype.beartype - def __init__(self, spacing: float = 50.0): - """ - Initialiser for spacing sampler - - Args: - spacing (float, optional): The distance between samples. Defaults to 50.0. - """ - self.sampler_label = "SamplerSpacing" - spacing = max(spacing, 1.0) - self.spacing = spacing - - @beartype.beartype - def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None - ) -> pandas.DataFrame: - """ - Execute sample method takes full point data, samples the data and returns the sampled points - - Args: - spatial_data (geopandas.GeoDataFrame): the data frame to sample (must contain column ["ID"]) - - Returns: - pandas.DataFrame: the sampled data points - """ - schema = {"ID": str, "X": float, "Y": float, "featureId": str} - df = pandas.DataFrame(columns=schema.keys()).astype(schema) - for _, row in spatial_data.iterrows(): - if type(row.geometry) is shapely.geometry.multipolygon.MultiPolygon: - targets = row.geometry.boundary.geoms - elif type(row.geometry) is shapely.geometry.polygon.Polygon: - targets = [row.geometry.boundary] - elif type(row.geometry) is shapely.geometry.multilinestring.MultiLineString: - targets = row.geometry.geoms - elif type(row.geometry) is shapely.geometry.linestring.LineString: - targets = [row.geometry] - else: - targets = [] - - # For the main cases Polygon and LineString the list 'targets' has one element - for a, target in enumerate(targets): - df2 = pandas.DataFrame(columns=schema.keys()).astype(schema) - distances = numpy.arange(0, target.length, self.spacing)[:-1] - points = [target.interpolate(distance) for distance in distances] - df2["X"] = [point.x for point in points] - df2["Y"] = [point.y for point in points] - - # # account for holes//rings in polygons - df2["featureId"] = str(a) - # 1. check if line is "closed" - if target.is_ring: - target_polygon = shapely.geometry.Polygon(target) - if target_polygon.exterior.is_ccw: # if counterclockwise --> hole - for j, target2 in enumerate(targets): - # skip if line or point - if len(target2.coords) >= 2: - continue - # which poly is the hole in? assign featureId of the same poly - t2_polygon = shapely.geometry.Polygon(target2) - if target.within(t2_polygon): # - df2['featureId'] = str(j) - - df2["ID"] = row["ID"] if "ID" in spatial_data.columns else 0 - df = df2 if len(df) == 0 else pandas.concat([df, df2]) - - df.reset_index(drop=True, inplace=True) - return df + Args: + spatial_data (geopandas.GeoDataFrame): the data frame to sample (must contain column ["ID"]) + + Returns: + pandas.DataFrame: the sampled data points + """ + spacing = max(spacing, 1.0) + schema = {"ID": str, "X": float, "Y": float, "featureId": str} + df = pandas.DataFrame(columns=schema.keys()).astype(schema) + for _, row in spatial_data.iterrows(): + if type(row.geometry) is shapely.geometry.multipolygon.MultiPolygon: + targets = row.geometry.boundary.geoms + elif type(row.geometry) is shapely.geometry.polygon.Polygon: + targets = [row.geometry.boundary] + elif type(row.geometry) is shapely.geometry.multilinestring.MultiLineString: + targets = row.geometry.geoms + elif type(row.geometry) is shapely.geometry.linestring.LineString: + targets = [row.geometry] + else: + targets = [] + + # For the main cases Polygon and LineString the list 'targets' has one element + for a, target in enumerate(targets): + df2 = pandas.DataFrame(columns=schema.keys()).astype(schema) + distances = numpy.arange(0, target.length, spacing)[:-1] + points = [target.interpolate(distance) for distance in distances] + df2["X"] = [point.x for point in points] + df2["Y"] = [point.y for point in points] + + # # account for holes//rings in polygons + df2["featureId"] = str(a) + # 1. check if line is "closed" + if target.is_ring: + target_polygon = shapely.geometry.Polygon(target) + if target_polygon.exterior.is_ccw: # if counterclockwise --> hole + for j, target2 in enumerate(targets): + # skip if line or point + if len(target2.coords) >= 2: + continue + # which poly is the hole in? assign featureId of the same poly + t2_polygon = shapely.geometry.Polygon(target2) + if target.within(t2_polygon): # + df2['featureId'] = str(j) + + df2["ID"] = row["ID"] if "ID" in spatial_data.columns else 0 + df = df2 if len(df) == 0 else pandas.concat([df, df2]) + + df.reset_index(drop=True, inplace=True) + return df From e01d2b77e91f58ef5d35c13f3cb8faa7bf4131ab Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Mon, 16 Jun 2025 12:51:55 +0800 Subject: [PATCH 05/14] refactor project class to use function based samplers --- map2loop/project.py | 120 ++++++++++++++++++++++++++++++-------------- 1 file changed, 81 insertions(+), 39 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index ffa1e1c2..0e2830dc 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -3,7 +3,7 @@ from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData -from .sampler import Sampler, SamplerDecimator, SamplerSpacing +from .sampler import sample_data, get_sampler from .thickness_calculator import InterpolatedStructure, ThicknessCalculator from .throw_calculator import ThrowCalculator, ThrowCalculatorAlpha from .fault_orientation import FaultOrientation @@ -39,8 +39,8 @@ class Project(object): ----------- verbose_level: m2l_enums.VerboseLevel A selection that defines how much console logging is output - samplers: Sampler - A list of samplers used to extract point samples from polyonal or line segments. Indexed by m2l_enum.Dataype + samplers: dict + A dictionary to store sampler names and parameters sorter: Sorter The sorting algorithm to use for calculating the stratigraphic column loop_filename: str @@ -135,7 +135,7 @@ def __init__( self._error_state = ErrorState.NONE self._error_state_msg = "" self.verbose_level = verbose_level - self.samplers = [SamplerDecimator()] * len(Datatype) + self.samplers = {} # Dictionary to store sampler names and parameters self.set_default_samplers() self.bounding_box = bounding_box self.sorter = SorterUseHint() @@ -397,41 +397,44 @@ def set_default_samplers(self): Initialisation function to set or reset the point samplers """ logger.info("Setting default samplers") - self.samplers[Datatype.STRUCTURE] = SamplerDecimator(1) - self.samplers[Datatype.GEOLOGY] = SamplerSpacing(50.0) - self.samplers[Datatype.FAULT] = SamplerSpacing(50.0) - self.samplers[Datatype.FOLD] = SamplerSpacing(50.0) - self.samplers[Datatype.DTM] = SamplerSpacing(50.0) + self.samplers = { + Datatype.STRUCTURE: {"name": "decimator", "params": {"decimation": 1}}, + Datatype.GEOLOGY: {"name": "spacing", "params": {"spacing": 50.0}}, + Datatype.FAULT: {"name": "spacing", "params": {"spacing": 50.0}}, + Datatype.FOLD: {"name": "spacing", "params": {"spacing": 50.0}}, + Datatype.DTM: {"name": "spacing", "params": {"spacing": 50.0}} + } @beartype.beartype - def set_sampler(self, datatype: Datatype, sampler: Sampler): + def set_sampler(self, datatype: Datatype, sampler_name: str, **kwargs): """ Set the point sampler for a specific datatype Args: datatype (Datatype): The datatype to use this sampler on - sampler (Sampler): - The sampler to use + sampler_name (str): + The name of the sampler to use + **kwargs: + Additional parameters for the sampler """ allowed_samplers = { - Datatype.STRUCTURE: SamplerDecimator, - Datatype.GEOLOGY: SamplerSpacing, - Datatype.FAULT: SamplerSpacing, - Datatype.FOLD: SamplerSpacing, - Datatype.DTM: SamplerSpacing, + Datatype.STRUCTURE: ["decimator"], + Datatype.GEOLOGY: ["spacing"], + Datatype.FAULT: ["spacing"], + Datatype.FOLD: ["spacing"], + Datatype.DTM: ["spacing"], } # Check for wrong sampler if datatype in allowed_samplers: - allowed_sampler_type = allowed_samplers[datatype] - if not isinstance(sampler, allowed_sampler_type): + if sampler_name not in allowed_samplers[datatype]: raise ValueError( - f"Got wrong argument for this datatype: {type(sampler).__name__}, please use {allowed_sampler_type.__name__} instead" + f"Invalid sampler {sampler_name} for datatype {datatype}, please use {allowed_samplers[datatype]} instead" ) - ## does the enum print the number or the label? - logger.info(f"Setting sampler for {datatype} to {sampler.sampler_label}") - self.samplers[datatype] = sampler + + logger.info(f"Setting sampler for {datatype} to {sampler_name}") + self.samplers[datatype] = {"name": sampler_name, "params": kwargs} @beartype.beartype def get_sampler(self, datatype: Datatype): @@ -444,7 +447,7 @@ def get_sampler(self, datatype: Datatype): Returns: str: The name of the sampler being used on the specified datatype """ - return self.samplers[datatype].sampler_label + return self.samplers[datatype]["name"] @beartype.beartype def set_minimum_fault_length(self, length: Union[float, int]): @@ -503,22 +506,53 @@ def sample_map_data(self): """ Use the samplers to extract points along polylines or unit boundaries """ - logger.info( - f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}" - ) + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) dtm_data = self.map_data.get_map_data(Datatype.DTM) - - self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data) - logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}") - - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data) - logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") - - self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT)) - logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD].sampler_label}") - self.fold_samples = self.samplers[Datatype.FOLD].sample(self.map_data.get_map_data(Datatype.FOLD)) + try: + logger.info(f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY]['name']}") + self.geology_samples = sample_data( + geology_data,self.samplers[Datatype.GEOLOGY]["name"], + **self.samplers[Datatype.GEOLOGY]["params"] + ) + except Exception as e: + logger.error(f"Error sampling geology map data: {str(e)}") + raise + try: + logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE]['name']}") + self.structure_samples = sample_data( + self.map_data.get_map_data(Datatype.STRUCTURE), + self.samplers[Datatype.STRUCTURE]["name"], + dtm_data=dtm_data, + geology_data=geology_data, + **self.samplers[Datatype.STRUCTURE]["params"] + ) + except Exception as e: + logger.error(f"Error sampling structure map data: {str(e)}") + raise + + try: + logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT]['name']}") + self.fault_samples = sample_data( + self.map_data.get_map_data(Datatype.FAULT), + self.samplers[Datatype.FAULT]["name"], + **self.samplers[Datatype.FAULT]["params"] + ) + except Exception as e: + logger.error(f"Error sampling fault map data: {str(e)}") + raise + + try: + logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD]['name']}") + self.fold_samples = sample_data( + self.map_data.get_map_data(Datatype.FOLD), + self.samplers[Datatype.FOLD]["name"], + **self.samplers[Datatype.FOLD]["params"] + ) + except Exception as e: + logger.error(f"Error sampling fold map data: {str(e)}") + raise def extract_geology_contacts(self): """ @@ -528,7 +562,11 @@ def extract_geology_contacts(self): self.map_data.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts - self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts) + self.map_data.sampled_contacts = sample_data( + self.map_data.basal_contacts, + self.samplers[Datatype.GEOLOGY]["name"], + **self.samplers[Datatype.GEOLOGY]["params"] + ) dtm_data = self.map_data.get_map_data(Datatype.DTM) set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts) @@ -775,7 +813,11 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): # Calculate basal contacts based on stratigraphic column self.extract_geology_contacts() - self.sample_map_data() + try: + self.sample_map_data() + except Exception as e: + logger.error(f"Error during map data sampling in run_all: {str(e)}") + raise self.calculate_unit_thicknesses() self.calculate_fault_orientations() self.summarise_fault_data() From 1384f449728c6e2bfd65f91b5e4c76acce7b64e8 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 17 Jun 2025 16:50:35 +0800 Subject: [PATCH 06/14] update ci from 3.2.3_dev --- .github/workflows/linting_and_testing.yml | 62 +++++++++++------------ 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/.github/workflows/linting_and_testing.yml b/.github/workflows/linting_and_testing.yml index 1f8f32a8..c8c89126 100644 --- a/.github/workflows/linting_and_testing.yml +++ b/.github/workflows/linting_and_testing.yml @@ -1,7 +1,20 @@ name: Linting and Testing on: - [push] + push: + branches: + - master + paths: + - '**.py' + - .github/workflows/linting_and_testing.yml + + pull_request: + branches: + - master + paths: + - '**.py' + - .github/workflows/linting_and_testing.yml + workflow_dispatch: jobs: linting: @@ -9,10 +22,6 @@ jobs: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.12' - name: Install dependencies run: | python -m pip install --upgrade pip @@ -25,43 +34,32 @@ jobs: testing: - name: Testing - runs-on: ubuntu-24.04 + name: Testing${{ matrix.os }} python ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ${{ fromJSON(vars.BUILD_OS)}} + python-version: ${{ fromJSON(vars.PYTHON_VERSIONS)}} + steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: conda-incubator/setup-miniconda@v3 with: - python-version: '3.12' - - name: Install GDAL - run: | - sudo add-apt-repository ppa:ubuntugis/ubuntugis-unstable - sudo apt-get update - sudo apt-get install -y libgdal-dev gdal-bin + python-version: ${{ matrix.python-version }} + conda-remove-defaults: "true" - - name: Set up Miniconda - uses: conda-incubator/setup-miniconda@v2 - with: - python-version: '3.12' - miniforge-version: latest - activate-environment: test-env - use-mamba: true - auto-activate-base: false - name: Install dependencies - shell: bash -l {0} run: | - mamba install python=3.12 -y - mamba install -c conda-forge gdal geopandas shapely networkx owslib beartype pytest scikit-learn -y - pip install map2model loopprojectfile==0.2.2 + conda run -n test conda info + conda run -n test conda install -c conda-forge -c loop3d --file dependencies.txt gdal python=${{ matrix.python-version }} pytest -y - name: Install map2loop - shell: bash -l {0} run: | - python -m pip install . + conda run -n test python -m pip install . - name: Run tests - shell: bash -l {0} run: | - python -c "import map2model" || echo "map2model not available, tests will use fallback mode" - pytest -v \ No newline at end of file + conda run -n test pytest + From 7f90abb312fb57151e1a953d78d6ab6555d29b20 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 17 Jun 2025 18:36:02 +0800 Subject: [PATCH 07/14] refactor sampler tests to use function based samplers --- tests/sampler/test_SamplerSpacing.py | 31 ++++++++++++++----- .../sampler/test_SamplerSpacing_featureId.py | 9 ++++-- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/tests/sampler/test_SamplerSpacing.py b/tests/sampler/test_SamplerSpacing.py index e69c7650..f02dcea8 100644 --- a/tests/sampler/test_SamplerSpacing.py +++ b/tests/sampler/test_SamplerSpacing.py @@ -1,5 +1,5 @@ import pandas -from map2loop.sampler import SamplerSpacing +from map2loop.sampler import sample_data from beartype.roar import BeartypeCallHintParamViolation import pytest import shapely @@ -11,7 +11,7 @@ @pytest.fixture def sampler_spacing(): - return SamplerSpacing(spacing=1.0) + return 1.0 @pytest.fixture @@ -37,7 +37,11 @@ def incorrect_geodata(): # test if correct outputs are generated from the right input def test_sample_function_correct_data(sampler_spacing, correct_geodata): - result = sampler_spacing.sample(correct_geodata) + result = sample_data( + spatial_data = correct_geodata, + sampler_name = 'spacing', + spacing = sampler_spacing + ) assert isinstance(result, pandas.DataFrame) assert 'X' in result.columns assert 'Y' in result.columns @@ -47,25 +51,36 @@ def test_sample_function_correct_data(sampler_spacing, correct_geodata): # add test for incorrect inputs - does it raise a BeartypeCallHintParamViolation error? def test_sample_function_incorrect_data(sampler_spacing, incorrect_geodata): with pytest.raises(BeartypeCallHintParamViolation): - sampler_spacing.sample(spatial_data=incorrect_geodata) + sample_data( + spatial_data = incorrect_geodata, + sampler_name = 'spacing', + spacing = sampler_spacing + ) # for a specific >2 case -def test_sample_function_target_less_than_or_equal_to_2(): - sampler_spacing = SamplerSpacing(spacing=1.0) +def test_sample_function_target_less_than_or_equal_to_2(sampler_spacing): data = { 'geometry': [shapely.LineString([(0, 0), (0, 1)]), shapely.LineString([(0, 0), (1, 0)])], 'ID': ['1', '2'], } gdf = geopandas.GeoDataFrame(data, geometry='geometry') - result = sampler_spacing.sample(spatial_data=gdf) + result = sample_data( + spatial_data = gdf, + sampler_name = 'spacing', + spacing = sampler_spacing + ) assert len(result) == 0 # No points should be sampled from the linestring # Test if the extracted points are correct def test_sample_function_extracted_points(sampler_spacing, correct_geodata): - result = sampler_spacing.sample(correct_geodata) + result = sample_data( + spatial_data=correct_geodata, + sampler_name = 'spacing', + spacing = sampler_spacing + ) expected_points = [ (0.0, 0.0), diff --git a/tests/sampler/test_SamplerSpacing_featureId.py b/tests/sampler/test_SamplerSpacing_featureId.py index a84df370..6a2820ab 100644 --- a/tests/sampler/test_SamplerSpacing_featureId.py +++ b/tests/sampler/test_SamplerSpacing_featureId.py @@ -1,5 +1,5 @@ import pandas -from map2loop.sampler import SamplerSpacing +from map2loop.sampler import sample_data import shapely import geopandas @@ -9,8 +9,11 @@ sampler_space = 700.0 -sampler = SamplerSpacing(spacing=sampler_space) -geology_samples = sampler.sample(geology_original) +geology_samples = sample_data( + spatial_data = geology_original, + sampler_name = 'spacing', + spacing = sampler_space +) # the actual test: From f284475df4c6a7b527446710d031f2c298c57ca5 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 17 Jun 2025 18:43:22 +0800 Subject: [PATCH 08/14] fix correctly update ci from 3.2.3_dev --- .github/workflows/linting_and_testing.yml | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/.github/workflows/linting_and_testing.yml b/.github/workflows/linting_and_testing.yml index c8c89126..9570a3bb 100644 --- a/.github/workflows/linting_and_testing.yml +++ b/.github/workflows/linting_and_testing.yml @@ -3,11 +3,11 @@ name: Linting and Testing on: push: branches: - - master + - master paths: - '**.py' - .github/workflows/linting_and_testing.yml - + pull_request: branches: - master @@ -50,16 +50,26 @@ jobs: conda-remove-defaults: "true" - - name: Install dependencies + - name: Install dependencies for windows python 3.10 + if: ${{ matrix.os == 'windows-latest' && matrix.python-version == '3.10' }} run: | conda run -n test conda info - conda run -n test conda install -c conda-forge -c loop3d --file dependencies.txt gdal python=${{ matrix.python-version }} pytest -y + conda run -n test conda install -c loop3d -c conda-forge "gdal=3.4.3" python=${{ matrix.python-version }} -y + conda run -n test conda install -c loop3d -c conda-forge --file dependencies.txt python=${{ matrix.python-version }} -y + conda run -n test conda install pytest python=${{ matrix.python-version }} -y + - name: Install dependencies for other environments + if: ${{ matrix.os != 'windows-latest' || matrix.python-version != '3.10' }} + run: | + conda run -n test conda info + conda run -n test conda install -c loop3d -c conda-forge gdal python=${{ matrix.python-version }} -y + conda run -n test conda install -c loop3d -c conda-forge --file dependencies.txt python=${{ matrix.python-version }} -y + conda run -n test conda install pytest python=${{ matrix.python-version }} -y + - name: Install map2loop run: | conda run -n test python -m pip install . - name: Run tests run: | - conda run -n test pytest - + conda run -n test pytest \ No newline at end of file From 5a585e6d7972b933f9e3056f457ea39ff38b34fc Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 17 Jun 2025 19:04:58 +0800 Subject: [PATCH 09/14] update dependencies.txt and map2loop init from 3.2.3_dev --- dependencies.txt | 3 +-- map2loop/__init__.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dependencies.txt b/dependencies.txt index 7a1d95e7..76f0f574 100644 --- a/dependencies.txt +++ b/dependencies.txt @@ -1,10 +1,9 @@ numpy scipy geopandas -shapely +shapely>=2 networkx owslib -map2model loopprojectfile==0.2.2 beartype pytest diff --git a/map2loop/__init__.py b/map2loop/__init__.py index d7ccac11..8723f4ef 100644 --- a/map2loop/__init__.py +++ b/map2loop/__init__.py @@ -30,7 +30,7 @@ class DependencyChecker: def __init__(self, package_name, dependency_file="dependencies.txt"): self.package_name = package_name - self.dependency_file = pathlib.Path(__file__).parent / dependency_file + self.dependency_file = pathlib.Path(__file__).parent.parent / dependency_file self.required_version = self.get_required_version() self.installed_version = self.get_installed_version() @@ -93,7 +93,7 @@ def check_version(self): def check_all_dependencies(dependency_file="dependencies.txt"): - dependencies_path = pathlib.Path(__file__).parent / dependency_file + dependencies_path = pathlib.Path(__file__).parent.parent / dependency_file try: with dependencies_path.open("r") as file: for line in file: @@ -103,6 +103,8 @@ def check_all_dependencies(dependency_file="dependencies.txt"): if line: if "==" in line: package_name, _ = line.split("==") + elif ">=" in line: + package_name, _ = line.split(">=") else: package_name = line From c0d46a595ec67a4351bbf62104960201c74aa427 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Wed, 18 Jun 2025 12:20:33 +0800 Subject: [PATCH 10/14] fix dtm data type and docstring in sampler --- map2loop/sampler.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 4e100782..a513bfe4 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -10,6 +10,7 @@ import numpy from typing import Optional from osgeo import gdal +import osgeo _SAMPLER_REGISTRY = {} @@ -34,6 +35,9 @@ def get_sampler(name: str): Args: name (str): the name of the sampler to retrieve + + Returns: + function: the sampler function """ if name not in _SAMPLER_REGISTRY: raise ValueError(f"Sampler {name} not found") @@ -43,15 +47,19 @@ def get_sampler(name: str): def sample_data( spatial_data: geopandas.GeoDataFrame, sampler_name: str, - dtm_data: Optional[geopandas.GeoDataFrame] = None, + dtm_data: Optional[osgeo.gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None, **kwargs )-> pandas.DataFrame: """ - Execute sampling method (abstract method) + Execute sampling method based on the provided sampler name Args: spatial_data (geopandas.GeoDataFrame): data frame to sample + sampler_name (str): the name of the sampler to use + dtm_data (Optional[osgeo.gdal.Dataset]): dtm data required for decimator sampler + geology_data (Optional[geopandas.GeoDataFrame]): geology data required for decimator sampler + **kwargs: additional arguments to pass to the sampler Returns: pandas.DataFrame: data frame containing samples @@ -68,7 +76,7 @@ def sample_data( @beartype.beartype def sample_decimator( spatial_data: geopandas.GeoDataFrame, - dtm_data: gdal.Dataset, + dtm_data: osgeo.gdal.Dataset, geology_data: geopandas.GeoDataFrame, decimation: int = 1 ) -> pandas.DataFrame: @@ -76,7 +84,10 @@ def sample_decimator( Execute sample method takes full point data, samples the data and returns the decimated points Args: - spatial_data (geopandas.GeoDataFrame): the data frame to sample + spatial_data (geopandas.GeoDataFrame): data frame to sample + dtm_data (osgeo.gdal.Dataset): dtm data + geology_data (geopandas.GeoDataFrame): geology data + decimation (int, optional): the decimation factor. Default to 1 Returns: pandas.DataFrame: the sampled data points @@ -106,6 +117,7 @@ def sample_spacing( Args: spatial_data (geopandas.GeoDataFrame): the data frame to sample (must contain column ["ID"]) + spacing (float, optional): the spacing between points. Default to 50.0 Returns: pandas.DataFrame: the sampled data points From 48cb912ae1567aed3405e92ec6e92c75742abb9a Mon Sep 17 00:00:00 2001 From: noellehmcheng <143368485+noellehmcheng@users.noreply.github.com> Date: Mon, 23 Jun 2025 07:31:57 +0000 Subject: [PATCH 11/14] style: style fixes by ruff and autoformatting by black --- map2loop/project.py | 2 +- map2loop/sampler.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 0e2830dc..a224513b 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -3,7 +3,7 @@ from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData -from .sampler import sample_data, get_sampler +from .sampler import sample_data from .thickness_calculator import InterpolatedStructure, ThicknessCalculator from .throw_calculator import ThrowCalculator, ThrowCalculatorAlpha from .fault_orientation import FaultOrientation diff --git a/map2loop/sampler.py b/map2loop/sampler.py index a513bfe4..0b7cedc6 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -2,14 +2,12 @@ from .utils import set_z_values_from_raster_df # external imports -from abc import ABC, abstractmethod import beartype import geopandas import pandas import shapely import numpy from typing import Optional -from osgeo import gdal import osgeo From 5f3dfb2fbffe32a7f6b78a0a61b95d1a74e03522 Mon Sep 17 00:00:00 2001 From: Rabii Chaarani <50892556+rabii-chaarani@users.noreply.github.com> Date: Wed, 25 Jun 2025 13:14:11 +0930 Subject: [PATCH 12/14] Refactor map2model wrapper as topology functions --- map2loop/map2model_wrapper.py | 190 ---------------------- map2loop/mapdata.py | 11 +- map2loop/project.py | 18 +- map2loop/topology.py | 117 +++++++++++++ tests/topology/__init__.py | 0 tests/topology/test_topology_functions.py | 114 +++++++++++++ 6 files changed, 245 insertions(+), 205 deletions(-) delete mode 100644 map2loop/map2model_wrapper.py create mode 100644 map2loop/topology.py create mode 100644 tests/topology/__init__.py create mode 100644 tests/topology/test_topology_functions.py diff --git a/map2loop/map2model_wrapper.py b/map2loop/map2model_wrapper.py deleted file mode 100644 index 115b8702..00000000 --- a/map2loop/map2model_wrapper.py +++ /dev/null @@ -1,190 +0,0 @@ -# internal imports -from .m2l_enums import VerboseLevel - -# external imports -import geopandas as gpd -import pandas as pd -import numpy as np - -from .logging import getLogger - -logger = getLogger(__name__) - - -class Map2ModelWrapper: - """ - A wrapper around map2model functionality - - Attributes - ---------- - sorted_units: None or list - map2model's estimate of the stratigraphic column - fault_fault_relationships: None or pandas.DataFrame - data frame of fault to fault relationships with columns ["Fault1", "Fault2", "Type", "Angle"] - unit_fault_relationships: None or pandas.DataFrame - data frame of unit fault relationships with columns ["Unit", "Fault"] - unit_unit_relationships: None or pandas.DataFrame - data frame of unit unit relationships with columns ["Index1", "UnitName1", "Index2", "UnitName2"] - map_data: MapData - A pointer to the map data structure in project - verbose_level: m2l_enum.VerboseLevel - A selection that defines how much console logging is output - """ - - def __init__( - self, map_data, *, verbose_level: VerboseLevel = VerboseLevel.NONE - ): - """ - The initialiser for the map2model wrapper - - Args: - map_data (MapData): - The project map data structure to reference - verbose_level (VerboseLevel, optional): - How much console output is sent. Defaults to VerboseLevel.ALL. - """ - self.sorted_units = None - self._fault_fault_relationships = None - self._unit_fault_relationships = None - self._unit_unit_relationships = None - self.map_data = map_data - self.verbose_level = verbose_level - self.buffer_radius = 500 - - @property - def fault_fault_relationships(self): - if self._fault_fault_relationships is None: - self._calculate_fault_fault_relationships() - - return self._fault_fault_relationships - - @property - def unit_fault_relationships(self): - if self._unit_fault_relationships is None: - self._calculate_fault_unit_relationships() - - return self._unit_fault_relationships - - @property - def unit_unit_relationships(self): - if self._unit_unit_relationships is None: - self._calculate_unit_unit_relationships() - - return self._unit_unit_relationships - - def reset(self): - """ - Reset the wrapper to before the map2model process - """ - logger.info("Resetting map2model wrapper") - self.sorted_units = None - self.fault_fault_relationships = None - self.unit_fault_relationships = None - self.unit_unit_relationships = None - - def get_sorted_units(self): - """ - Getter for the map2model sorted units - - Returns: - list: The map2model stratigraphic column estimate - """ - raise NotImplementedError("This method is not implemented") - - - def get_fault_fault_relationships(self): - """ - Getter for the fault fault relationships - - Returns: - pandas.DataFrame: The fault fault relationships - """ - - return self.fault_fault_relationships - - def get_unit_fault_relationships(self): - """ - Getter for the unit fault relationships - - Returns: - pandas.DataFrame: The unit fault relationships - """ - - return self.unit_fault_relationships - - def get_unit_unit_relationships(self): - """ - Getter for the unit unit relationships - - Returns: - pandas.DataFrame: The unit unit relationships - """ - - return self.unit_unit_relationships - - def _calculate_fault_fault_relationships(self): - - faults = self.map_data.FAULT.copy() - # reset index so that we can index the adjacency matrix with the index - faults.reset_index(inplace=True) - buffers = faults.buffer(self.buffer_radius) - # create the adjacency matrix - intersection = gpd.sjoin( - gpd.GeoDataFrame(geometry=buffers), gpd.GeoDataFrame(geometry=faults["geometry"]) - ) - intersection["index_left"] = intersection.index - intersection.reset_index(inplace=True) - - adjacency_matrix = np.zeros((faults.shape[0], faults.shape[0]), dtype=bool) - adjacency_matrix[ - intersection.loc[:, "index_left"], intersection.loc[:, "index_right"] - ] = True - f1, f2 = np.where(np.tril(adjacency_matrix, k=-1)) - df = pd.DataFrame( - {'Fault1': faults.loc[f1, 'ID'].to_list(), 'Fault2': faults.loc[f2, 'ID'].to_list()} - ) - df['Angle'] = 60 # make it big to prevent LS from making splays - df['Type'] = 'T' - self._fault_fault_relationships = df - - def _calculate_fault_unit_relationships(self): - """Calculate unit/fault relationships using geopandas sjoin. - This will return - """ - units = self.map_data.GEOLOGY["UNITNAME"].unique() - faults = self.map_data.FAULT.copy().reset_index().drop(columns=['index']) - adjacency_matrix = np.zeros((len(units), faults.shape[0]), dtype=bool) - for i, u in enumerate(units): - unit = self.map_data.GEOLOGY[self.map_data.GEOLOGY["UNITNAME"] == u] - intersection = gpd.sjoin( - gpd.GeoDataFrame(geometry=faults["geometry"]), - gpd.GeoDataFrame(geometry=unit["geometry"]), - ) - intersection["index_left"] = intersection.index - intersection.reset_index(inplace=True) - adjacency_matrix[i, intersection.loc[:, "index_left"]] = True - u, f = np.where(adjacency_matrix) - df = pd.DataFrame({"Unit": units[u].tolist(), "Fault": faults.loc[f, "ID"].to_list()}) - self._unit_fault_relationships = df - - def _calculate_unit_unit_relationships(self): - if self.map_data.contacts is None: - self.map_data.extract_all_contacts() - self._unit_unit_relationships = self.map_data.contacts.copy().drop( - columns=['length', 'geometry'] - ) - return self._unit_unit_relationships - - def run(self, verbose_level: VerboseLevel = None): - """ - The main execute function that prepares, runs and parse the output of the map2model process - - Args: - verbose_level (VerboseLevel, optional): - How much console output is sent. Defaults to None (which uses the wrapper attribute). - """ - - self.get_fault_fault_relationships() - self.get_unit_fault_relationships() - self.get_unit_unit_relationships() - diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index df5f6804..d8ae98c1 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -1360,11 +1360,10 @@ def calculate_bounding_box_and_projection(self): def export_wkt_format_files(self): """ Save out the geology and fault GeoDataFrames in WKT format - This is used by map2model + This is used by topology """ - # TODO: - Move away from tab seperators entirely (topology and map2model) - - self.map2model_tmp_path = pathlib.Path(tempfile.mkdtemp()) + # TODO: - Move away from tab seperators entirely (topology) + self.topology_tmp_path = pathlib.Path(tempfile.mkdtemp()) # Check geology data status and export to a WKT format file self.load_map_data(Datatype.GEOLOGY) @@ -1398,7 +1397,7 @@ def export_wkt_format_files(self): geology["ROCKTYPE1"] = geology["ROCKTYPE1"].replace("", "None") geology["ROCKTYPE2"] = geology["ROCKTYPE2"].replace("", "None") geology.to_csv( - pathlib.Path(self.map2model_tmp_path) / "geology_wkt.csv", sep="\t", index=False + pathlib.Path(self.topology_tmp_path) / "geology_wkt.csv", sep="\t", index=False ) # Check faults data status and export to a WKT format file @@ -1413,7 +1412,7 @@ def export_wkt_format_files(self): faults = self.get_map_data(Datatype.FAULT).copy() faults.rename(columns={"geometry": "WKT"}, inplace=True) faults.to_csv( - pathlib.Path(self.map2model_tmp_path) / "faults_wkt.csv", sep="\t", index=False + pathlib.Path(self.topology_tmp_path) / "faults_wkt.csv", sep="\t", index=False ) @beartype.beartype diff --git a/map2loop/project.py b/map2loop/project.py index a224513b..ab122d5f 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -10,7 +10,7 @@ from .sorter import Sorter, SorterAgeBased, SorterAlpha, SorterUseNetworkX, SorterUseHint from .stratigraphic_column import StratigraphicColumn from .deformation_history import DeformationHistory -from .map2model_wrapper import Map2ModelWrapper +from .topology import run_topology, calculate_unit_unit_relationships from .data_checks import validate_config_dictionary # external imports @@ -47,8 +47,8 @@ class Project(object): The name of the loop project file used in this project map_data: MapData The structure that holds all map and dtm data - map2model: Map2ModelWrapper - A wrapper around the map2model module that extracts unit and fault adjacency + topology_results: dict + A dictionary storing results from the topology process stratigraphic_column: StratigraphicColumn The structure that holds the unit information and ordering deformation_history: DeformationHistory @@ -143,7 +143,7 @@ def __init__( self.throw_calculator = ThrowCalculatorAlpha() self.fault_orientation = FaultOrientationNearest() self.map_data = MapData(verbose_level=verbose_level) - self.map2model = Map2ModelWrapper(self.map_data) + self.topology_results = None self.stratigraphic_column = StratigraphicColumn() self.deformation_history = DeformationHistory(project=self) self.loop_filename = loop_project_filename @@ -583,7 +583,7 @@ def calculate_stratigraphic_order(self, take_best=False): columns = [ sorter.sort( self.stratigraphic_column.stratigraphicUnits, - self.map2model.get_unit_unit_relationships(), + calculate_unit_unit_relationships(self.map_data), self.map_data.contacts, self.map_data, ) @@ -613,7 +613,7 @@ def calculate_stratigraphic_order(self, take_best=False): logger.info(f'Calculating stratigraphic column using sorter {self.sorter.sorter_label}') self.stratigraphic_column.column = self.sorter.sort( self.stratigraphic_column.stratigraphicUnits, - self.map2model.get_unit_unit_relationships(), + calculate_unit_unit_relationships(self.map_data), self.map_data.contacts, self.map_data, ) @@ -802,7 +802,7 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): # Calculate the stratigraphic column if issubclass(type(user_defined_stratigraphic_column), list): self.stratigraphic_column.column = user_defined_stratigraphic_column - self.map2model.run() # if we use a user defined stratigraphic column, we still need to calculate the results of map2model + self.topology_results = run_topology(self.map_data) else: if user_defined_stratigraphic_column is not None: logger.warning( @@ -1050,9 +1050,9 @@ def save_into_projectfile(self): observations["dipPolarity"] = self.structure_samples["OVERTURNED"] LPF.Set(self.loop_filename, "stratigraphicObservations", data=observations) - if self.map2model.fault_fault_relationships is not None: + if self.topology_results and self.topology_results.get("fault_fault_relationships") is not None: ff_relationships = self.deformation_history.get_fault_relationships_with_ids( - self.map2model.fault_fault_relationships + self.topology_results["fault_fault_relationships"] ) relationships = numpy.zeros(len(ff_relationships), LPF.eventRelationshipType) diff --git a/map2loop/topology.py b/map2loop/topology.py new file mode 100644 index 00000000..99faaaab --- /dev/null +++ b/map2loop/topology.py @@ -0,0 +1,117 @@ +# internal imports +from .mapdata import MapData +from .logging import getLogger + +# external imports +import geopandas as gpd +import pandas as pd +import numpy as np +import beartype +from beartype.typing import Dict + +logger = getLogger(__name__) + +_TOPOLOGY_REGISTRY = {} + +@beartype.beartype +def calculate_fault_fault_relationships( + map_data: MapData, + buffer_radius: float = 500, +) -> pd.DataFrame: + """Calculate fault to fault relationships.""" + faults = map_data.FAULT.copy() + faults.reset_index(inplace=True) + buffers = faults.buffer(buffer_radius) + intersection = gpd.sjoin( + gpd.GeoDataFrame(geometry=buffers), + gpd.GeoDataFrame(geometry=faults["geometry"]), + ) + intersection["index_left"] = intersection.index + intersection.reset_index(inplace=True) + + adjacency_matrix = np.zeros((faults.shape[0], faults.shape[0]), dtype=bool) + adjacency_matrix[ + intersection.loc[:, "index_left"], + intersection.loc[:, "index_right"], + ] = True + f1, f2 = np.where(np.tril(adjacency_matrix, k=-1)) + df = pd.DataFrame( + { + "Fault1": faults.loc[f1, "ID"].to_list(), + "Fault2": faults.loc[f2, "ID"].to_list(), + } + ) + df["Angle"] = 60 + df["Type"] = "T" + return df + +@beartype.beartype +def calculate_unit_fault_relationships( + map_data: MapData, + buffer_radius: float = 500, +) -> pd.DataFrame: + """Calculate unit to fault relationships.""" + units = map_data.GEOLOGY["UNITNAME"].unique() + faults = map_data.FAULT.copy().reset_index().drop(columns=["index"]) + adjacency_matrix = np.zeros((len(units), faults.shape[0]), dtype=bool) + for i, u in enumerate(units): + unit = map_data.GEOLOGY[map_data.GEOLOGY["UNITNAME"] == u] + intersection = gpd.sjoin( + gpd.GeoDataFrame(geometry=faults["geometry"]), + gpd.GeoDataFrame(geometry=unit["geometry"]), + ) + intersection["index_left"] = intersection.index + intersection.reset_index(inplace=True) + adjacency_matrix[i, intersection.loc[:, "index_left"]] = True + u_idx, f_idx = np.where(adjacency_matrix) + df = pd.DataFrame({"Unit": units[u_idx].tolist(), "Fault": faults.loc[f_idx, "ID"].to_list()}) + return df + +@beartype.beartype +def calculate_unit_unit_relationships(map_data: MapData) -> pd.DataFrame: + """Calculate unit to unit relationships.""" + if map_data.contacts is None: + map_data.extract_all_contacts() + return map_data.contacts.copy().drop(columns=["length", "geometry"]) + +@beartype.beartype +def register_topology(name: str): + """Register a topology function with a given name.""" + def decorator(func): + _TOPOLOGY_REGISTRY[name] = func + return func + return decorator + +@beartype.beartype +def get_topology(name: str): + """Retrieve a registered topology function.""" + if name not in _TOPOLOGY_REGISTRY: + raise ValueError(f"Topology {name} not found") + return _TOPOLOGY_REGISTRY[name] + +@beartype.beartype +def run_topology( + map_data: MapData, + topology_name: str = "default", + **kwargs, +) -> Dict[str, pd.DataFrame]: + """Execute a topology function by name.""" + runner = get_topology(topology_name) + return runner(map_data=map_data, **kwargs) + +@register_topology("default") +@beartype.beartype +def topology_default( + map_data: MapData, + buffer_radius: float = 500, +) -> Dict[str, pd.DataFrame]: + """Calculate topology relationships using basic geopandas operations.""" + ff_df = calculate_fault_fault_relationships(map_data, buffer_radius) + uf_df = calculate_unit_fault_relationships(map_data, buffer_radius) + uu_df = calculate_unit_unit_relationships(map_data) + + return { + "fault_fault_relationships": ff_df, + "unit_fault_relationships": uf_df, + "unit_unit_relationships": uu_df, + } diff --git a/tests/topology/__init__.py b/tests/topology/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/topology/test_topology_functions.py b/tests/topology/test_topology_functions.py new file mode 100644 index 00000000..c8fb1b5e --- /dev/null +++ b/tests/topology/test_topology_functions.py @@ -0,0 +1,114 @@ +import sys +import types + +osgeo_stub = types.ModuleType("osgeo") +osgeo_stub.gdal = types.ModuleType("gdal") +osgeo_stub.osr = types.ModuleType("osr") +def _noop(): + pass +osgeo_stub.gdal.UseExceptions = _noop +class _Dataset: + def GetGeoTransform(self): + return (0, 1, 0, 0, 0, 1) +osgeo_stub.gdal.Dataset = _Dataset +def InvGeoTransform(gt): + return gt +osgeo_stub.gdal.InvGeoTransform = InvGeoTransform +sys.modules.setdefault("osgeo", osgeo_stub) +sys.modules.setdefault("osgeo.gdal", osgeo_stub.gdal) +sys.modules.setdefault("osgeo.osr", osgeo_stub.osr) + +import geopandas as gpd +from shapely.geometry import LineString, Polygon +import pandas as pd + +from map2loop.mapdata import MapData +from map2loop.m2l_enums import Datatype, Datastate +from map2loop.topology import ( + calculate_fault_fault_relationships, + calculate_unit_fault_relationships, + calculate_unit_unit_relationships, + register_topology, + run_topology, +) + + +def _create_basic_mapdata(): + md = MapData() + faults = gpd.GeoDataFrame( + { + "geometry": [ + LineString([(0, 0), (1, 0)]), + LineString([(0, 0), (0, 1)]), + ], + "ID": ["F1", "F2"], + }, + geometry="geometry", + crs="EPSG:4326", + ) + md.data[Datatype.FAULT] = faults + md.data_states[Datatype.FAULT] = Datastate.COMPLETE + md.dirtyflags[Datatype.FAULT] = False + + geology = gpd.GeoDataFrame( + { + "UNITNAME": ["U1", "U2"], + "geometry": [ + Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]), + Polygon([(1, 0), (2, 0), (2, 1), (1, 1)]), + ], + }, + geometry="geometry", + crs="EPSG:4326", + ) + md.data[Datatype.GEOLOGY] = geology + md.data_states[Datatype.GEOLOGY] = Datastate.COMPLETE + md.dirtyflags[Datatype.GEOLOGY] = False + + contacts = pd.DataFrame( + { + "UNITNAME_1": ["U1"], + "UNITNAME_2": ["U2"], + "length": [1.0], + "geometry": [LineString([(1, 0), (1, 1)])], + } + ) + md.contacts = contacts + return md + + +def test_calculate_fault_fault_relationships(): + md = _create_basic_mapdata() + df = calculate_fault_fault_relationships(md, buffer_radius=0.1) + assert len(df) == 1 + assert set(df.iloc[0][["Fault1", "Fault2"]]) == {"F1", "F2"} + + +def test_calculate_unit_fault_relationships(): + md = _create_basic_mapdata() + df = calculate_unit_fault_relationships(md, buffer_radius=0.1) + pairs = set(tuple(row) for row in df[["Unit", "Fault"]].to_records(index=False)) + assert pairs == {("U1", "F1"), ("U1", "F2"), ("U2", "F1")} + + +def test_calculate_unit_unit_relationships(): + md = _create_basic_mapdata() + df = calculate_unit_unit_relationships(md) + assert list(df.columns) == ["UNITNAME_1", "UNITNAME_2"] + assert df.iloc[0]["UNITNAME_1"] == "U1" + assert df.iloc[0]["UNITNAME_2"] == "U2" + + +def test_registry_and_runner(): + called = {} + + @register_topology("test") + def run_test(map_data: MapData): + called["executed"] = True + return {"dummy": pd.DataFrame()} + + md = _create_basic_mapdata() + result = run_topology(md, "test") + assert "executed" in called + assert list(result.keys()) == ["dummy"] + assert isinstance(result["dummy"], pd.DataFrame) From c1506a5c132c954d29cd37fb611d0463beef71cb Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Wed, 25 Jun 2025 13:25:20 +0930 Subject: [PATCH 13/14] fix: add topology registration and retrieval functions --- map2loop/topology.py | 34 ++++++++++++----------- tests/topology/test_topology_functions.py | 2 +- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/map2loop/topology.py b/map2loop/topology.py index 99faaaab..fcb4734a 100644 --- a/map2loop/topology.py +++ b/map2loop/topology.py @@ -13,6 +13,22 @@ _TOPOLOGY_REGISTRY = {} +@beartype.beartype +def register_topology(name: str): + """Register a topology function with a given name.""" + def decorator(func): + _TOPOLOGY_REGISTRY[name] = func + return func + return decorator + +@beartype.beartype +def get_topology(name: str): + """Retrieve a registered topology function.""" + if name not in _TOPOLOGY_REGISTRY: + raise ValueError(f"Topology {name} not found") + return _TOPOLOGY_REGISTRY[name] + +@register_topology("fault_fault_relationships") @beartype.beartype def calculate_fault_fault_relationships( map_data: MapData, @@ -45,6 +61,7 @@ def calculate_fault_fault_relationships( df["Type"] = "T" return df +@register_topology("unit_fault_relationships") @beartype.beartype def calculate_unit_fault_relationships( map_data: MapData, @@ -66,7 +83,7 @@ def calculate_unit_fault_relationships( u_idx, f_idx = np.where(adjacency_matrix) df = pd.DataFrame({"Unit": units[u_idx].tolist(), "Fault": faults.loc[f_idx, "ID"].to_list()}) return df - +@register_topology("unit_unit_relationships") @beartype.beartype def calculate_unit_unit_relationships(map_data: MapData) -> pd.DataFrame: """Calculate unit to unit relationships.""" @@ -74,21 +91,6 @@ def calculate_unit_unit_relationships(map_data: MapData) -> pd.DataFrame: map_data.extract_all_contacts() return map_data.contacts.copy().drop(columns=["length", "geometry"]) -@beartype.beartype -def register_topology(name: str): - """Register a topology function with a given name.""" - def decorator(func): - _TOPOLOGY_REGISTRY[name] = func - return func - return decorator - -@beartype.beartype -def get_topology(name: str): - """Retrieve a registered topology function.""" - if name not in _TOPOLOGY_REGISTRY: - raise ValueError(f"Topology {name} not found") - return _TOPOLOGY_REGISTRY[name] - @beartype.beartype def run_topology( map_data: MapData, diff --git a/tests/topology/test_topology_functions.py b/tests/topology/test_topology_functions.py index c8fb1b5e..44f77ba3 100644 --- a/tests/topology/test_topology_functions.py +++ b/tests/topology/test_topology_functions.py @@ -87,7 +87,7 @@ def test_calculate_fault_fault_relationships(): def test_calculate_unit_fault_relationships(): md = _create_basic_mapdata() df = calculate_unit_fault_relationships(md, buffer_radius=0.1) - pairs = set(tuple(row) for row in df[["Unit", "Fault"]].to_records(index=False)) + pairs = {tuple(row) for row in df[["Unit", "Fault"]].to_records(index=False)} assert pairs == {("U1", "F1"), ("U1", "F2"), ("U2", "F1")} From b8e02c28990de4ac188013ad8a1f18a8b4ed060d Mon Sep 17 00:00:00 2001 From: Rabii Chaarani <50892556+rabii-chaarani@users.noreply.github.com> Date: Wed, 25 Jun 2025 14:44:51 +0930 Subject: [PATCH 14/14] Refactor topology API to use GeoDataFrames (#215) * refactor: update topology functions to use gdf layers * refactor: update imports and function signatures --- map2loop/project.py | 5 +- map2loop/topology.py | 85 ++++++++++++++--------- tests/topology/test_topology_functions.py | 41 +++-------- 3 files changed, 66 insertions(+), 65 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index ab122d5f..0c01eeb5 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -802,7 +802,10 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): # Calculate the stratigraphic column if issubclass(type(user_defined_stratigraphic_column), list): self.stratigraphic_column.column = user_defined_stratigraphic_column - self.topology_results = run_topology(self.map_data) + self.topology_results = run_topology( + self.map_data.FAULT, + self.map_data.GEOLOGY, + ) else: if user_defined_stratigraphic_column is not None: logger.warning( diff --git a/map2loop/topology.py b/map2loop/topology.py index fcb4734a..724f59de 100644 --- a/map2loop/topology.py +++ b/map2loop/topology.py @@ -1,11 +1,12 @@ # internal imports -from .mapdata import MapData from .logging import getLogger +from map2loop.map_data import MapData # external imports -import geopandas as gpd -import pandas as pd -import numpy as np +import inspect +import geopandas +import pandas +import numpy import beartype from beartype.typing import Dict @@ -31,27 +32,27 @@ def get_topology(name: str): @register_topology("fault_fault_relationships") @beartype.beartype def calculate_fault_fault_relationships( - map_data: MapData, + fault_layer: geopandas.GeoDataFrame, buffer_radius: float = 500, -) -> pd.DataFrame: +) -> pandas.DataFrame: """Calculate fault to fault relationships.""" - faults = map_data.FAULT.copy() + faults = fault_layer.copy() faults.reset_index(inplace=True) buffers = faults.buffer(buffer_radius) - intersection = gpd.sjoin( - gpd.GeoDataFrame(geometry=buffers), - gpd.GeoDataFrame(geometry=faults["geometry"]), + intersection = geopandas.sjoin( + geopandas.GeoDataFrame(geometry=buffers), + geopandas.GeoDataFrame(geometry=faults["geometry"]), ) intersection["index_left"] = intersection.index intersection.reset_index(inplace=True) - adjacency_matrix = np.zeros((faults.shape[0], faults.shape[0]), dtype=bool) + adjacency_matrix = numpy.zeros((faults.shape[0], faults.shape[0]), dtype=bool) adjacency_matrix[ intersection.loc[:, "index_left"], intersection.loc[:, "index_right"], ] = True - f1, f2 = np.where(np.tril(adjacency_matrix, k=-1)) - df = pd.DataFrame( + f1, f2 = numpy.where(numpy.tril(adjacency_matrix, k=-1)) + df = pandas.DataFrame( { "Fault1": faults.loc[f1, "ID"].to_list(), "Fault2": faults.loc[f2, "ID"].to_list(), @@ -64,53 +65,69 @@ def calculate_fault_fault_relationships( @register_topology("unit_fault_relationships") @beartype.beartype def calculate_unit_fault_relationships( - map_data: MapData, + fault_layer: geopandas.GeoDataFrame, + geology_layer: geopandas.GeoDataFrame, buffer_radius: float = 500, -) -> pd.DataFrame: +) -> pandas.DataFrame: """Calculate unit to fault relationships.""" - units = map_data.GEOLOGY["UNITNAME"].unique() - faults = map_data.FAULT.copy().reset_index().drop(columns=["index"]) - adjacency_matrix = np.zeros((len(units), faults.shape[0]), dtype=bool) + units = geology_layer["UNITNAME"].unique() + faults = fault_layer.copy().reset_index().drop(columns=["index"]) + adjacency_matrix = numpy.zeros((len(units), faults.shape[0]), dtype=bool) for i, u in enumerate(units): - unit = map_data.GEOLOGY[map_data.GEOLOGY["UNITNAME"] == u] - intersection = gpd.sjoin( - gpd.GeoDataFrame(geometry=faults["geometry"]), - gpd.GeoDataFrame(geometry=unit["geometry"]), + unit = geology_layer[geology_layer["UNITNAME"] == u] + intersection = geopandas.sjoin( + geopandas.GeoDataFrame(geometry=faults["geometry"]), + geopandas.GeoDataFrame(geometry=unit["geometry"]), ) intersection["index_left"] = intersection.index intersection.reset_index(inplace=True) adjacency_matrix[i, intersection.loc[:, "index_left"]] = True - u_idx, f_idx = np.where(adjacency_matrix) - df = pd.DataFrame({"Unit": units[u_idx].tolist(), "Fault": faults.loc[f_idx, "ID"].to_list()}) + u_idx, f_idx = numpy.where(adjacency_matrix) + df = pandas.DataFrame({"Unit": units[u_idx].tolist(), "Fault": faults.loc[f_idx, "ID"].to_list()}) return df @register_topology("unit_unit_relationships") @beartype.beartype -def calculate_unit_unit_relationships(map_data: MapData) -> pd.DataFrame: +def calculate_unit_unit_relationships( + geology_layer: geopandas.GeoDataFrame, + contacts: pandas.DataFrame = None, + map_data: MapData = None, +) -> pandas.DataFrame: """Calculate unit to unit relationships.""" - if map_data.contacts is None: + if contacts is None: map_data.extract_all_contacts() return map_data.contacts.copy().drop(columns=["length", "geometry"]) @beartype.beartype def run_topology( - map_data: MapData, + fault_layer: geopandas.GeoDataFrame, + geology_layer: geopandas.GeoDataFrame, topology_name: str = "default", **kwargs, -) -> Dict[str, pd.DataFrame]: +) -> Dict[str, pandas.DataFrame]: """Execute a topology function by name.""" runner = get_topology(topology_name) - return runner(map_data=map_data, **kwargs) + signature = inspect.signature(runner) + call_args = {} + if "fault_layer" in signature.parameters: + call_args["fault_layer"] = fault_layer + if "geology_layer" in signature.parameters: + call_args["geology_layer"] = geology_layer + call_args.update(kwargs) + return runner(**call_args) @register_topology("default") @beartype.beartype def topology_default( - map_data: MapData, + fault_layer: geopandas.GeoDataFrame, + geology_layer: geopandas.GeoDataFrame, + contacts: pandas.DataFrame = None, + map_data: MapData = None, buffer_radius: float = 500, -) -> Dict[str, pd.DataFrame]: +) -> Dict[str, pandas.DataFrame]: """Calculate topology relationships using basic geopandas operations.""" - ff_df = calculate_fault_fault_relationships(map_data, buffer_radius) - uf_df = calculate_unit_fault_relationships(map_data, buffer_radius) - uu_df = calculate_unit_unit_relationships(map_data) + ff_df = calculate_fault_fault_relationships(fault_layer, buffer_radius) + uf_df = calculate_unit_fault_relationships(fault_layer, geology_layer, buffer_radius) + uu_df = calculate_unit_unit_relationships(geology_layer, contacts, map_data) return { "fault_fault_relationships": ff_df, diff --git a/tests/topology/test_topology_functions.py b/tests/topology/test_topology_functions.py index 44f77ba3..b0a670c5 100644 --- a/tests/topology/test_topology_functions.py +++ b/tests/topology/test_topology_functions.py @@ -21,9 +21,6 @@ def InvGeoTransform(gt): import geopandas as gpd from shapely.geometry import LineString, Polygon import pandas as pd - -from map2loop.mapdata import MapData -from map2loop.m2l_enums import Datatype, Datastate from map2loop.topology import ( calculate_fault_fault_relationships, calculate_unit_fault_relationships, @@ -33,8 +30,7 @@ def InvGeoTransform(gt): ) -def _create_basic_mapdata(): - md = MapData() +def _create_basic_layers(): faults = gpd.GeoDataFrame( { "geometry": [ @@ -46,9 +42,6 @@ def _create_basic_mapdata(): geometry="geometry", crs="EPSG:4326", ) - md.data[Datatype.FAULT] = faults - md.data_states[Datatype.FAULT] = Datastate.COMPLETE - md.dirtyflags[Datatype.FAULT] = False geology = gpd.GeoDataFrame( { @@ -61,39 +54,27 @@ def _create_basic_mapdata(): geometry="geometry", crs="EPSG:4326", ) - md.data[Datatype.GEOLOGY] = geology - md.data_states[Datatype.GEOLOGY] = Datastate.COMPLETE - md.dirtyflags[Datatype.GEOLOGY] = False - contacts = pd.DataFrame( - { - "UNITNAME_1": ["U1"], - "UNITNAME_2": ["U2"], - "length": [1.0], - "geometry": [LineString([(1, 0), (1, 1)])], - } - ) - md.contacts = contacts - return md + return faults, geology def test_calculate_fault_fault_relationships(): - md = _create_basic_mapdata() - df = calculate_fault_fault_relationships(md, buffer_radius=0.1) + faults, geology = _create_basic_layers() + df = calculate_fault_fault_relationships(faults, buffer_radius=0.1) assert len(df) == 1 assert set(df.iloc[0][["Fault1", "Fault2"]]) == {"F1", "F2"} def test_calculate_unit_fault_relationships(): - md = _create_basic_mapdata() - df = calculate_unit_fault_relationships(md, buffer_radius=0.1) + faults, geology = _create_basic_layers() + df = calculate_unit_fault_relationships(faults, geology, buffer_radius=0.1) pairs = {tuple(row) for row in df[["Unit", "Fault"]].to_records(index=False)} assert pairs == {("U1", "F1"), ("U1", "F2"), ("U2", "F1")} def test_calculate_unit_unit_relationships(): - md = _create_basic_mapdata() - df = calculate_unit_unit_relationships(md) + faults, geology = _create_basic_layers() + df = calculate_unit_unit_relationships(geology) assert list(df.columns) == ["UNITNAME_1", "UNITNAME_2"] assert df.iloc[0]["UNITNAME_1"] == "U1" assert df.iloc[0]["UNITNAME_2"] == "U2" @@ -103,12 +84,12 @@ def test_registry_and_runner(): called = {} @register_topology("test") - def run_test(map_data: MapData): + def run_test(fault_layer: gpd.GeoDataFrame, geology_layer: gpd.GeoDataFrame): called["executed"] = True return {"dummy": pd.DataFrame()} - md = _create_basic_mapdata() - result = run_topology(md, "test") + faults, geology = _create_basic_layers() + result = run_topology(faults, geology, "test") assert "executed" in called assert list(result.keys()) == ["dummy"] assert isinstance(result["dummy"], pd.DataFrame)