From 1821f6d88bb1774ca85a2529a55906fcb4e870ad Mon Sep 17 00:00:00 2001 From: Githubcopilot111 Date: Fri, 11 Oct 2024 02:35:54 +0800 Subject: [PATCH 1/6] Use shapely to match centroid to catchment polygon --- MANIFEST.in | 2 + build.bat | 61 ++++++ installer_osgeo4w.bat | 41 ++++ pyproject.toml | 27 ++- setup.py | 29 +++ src/app.py | 186 +++++++++++++++--- src/app_testing.py | 67 +++++++ src/pyromb/core/attributes/reach.py | 48 +++-- .../core/geometry/shapefile_validation.py | 148 ++++++++++++++ src/pyromb/core/gis/builder.py | 81 +++++--- src/pyromb/resources/expected_fields.json | 37 ++++ 11 files changed, 642 insertions(+), 85 deletions(-) create mode 100644 MANIFEST.in create mode 100644 build.bat create mode 100644 installer_osgeo4w.bat create mode 100644 setup.py create mode 100644 src/app_testing.py create mode 100644 src/pyromb/core/geometry/shapefile_validation.py create mode 100644 src/pyromb/resources/expected_fields.json diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..d653aef --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +# Include all JSON files in the resources directory and its subdirectories +recursive-include src/pyromb/resources *.json diff --git a/build.bat b/build.bat new file mode 100644 index 0000000..3e30280 --- /dev/null +++ b/build.bat @@ -0,0 +1,61 @@ +@echo off +setlocal + +REM Verify Python and Pip paths +echo Using Python from: +where python +echo Using pip from: +where pip + +REM Check setuptools version +python -m pip show setuptools + +REM Upgrade setuptools and wheel +echo Upgrading setuptools and wheel... +python -m pip install --upgrade setuptools wheel + +REM Define the directory where the package will be stored +set "PACKAGE_DIR=%~dp0dist" +echo Package directory: %PACKAGE_DIR% + +REM Navigate to the directory containing the setup.py script +cd /d "%~dp0" + +REM Clean previous builds +if exist "%PACKAGE_DIR%" ( + echo Cleaning previous builds... 
+ rmdir /s /q "%PACKAGE_DIR%" +) + +REM Build the source distribution and wheel using setuptools +echo Building the package using setuptools... +python setup.py sdist bdist_wheel + +REM Check if the build was successful +if %ERRORLEVEL% neq 0 ( + echo Build failed. Please check the setup.py and pyproject.toml for errors. + endlocal + pause + goto :EOF +) + +REM Create the dist directory if it doesn't exist +if not exist "%PACKAGE_DIR%" mkdir "%PACKAGE_DIR%" + +REM Move the generated .tar.gz and .whl files to the desired folder +echo Moving built packages to %PACKAGE_DIR%... +move /Y "dist\*.tar.gz" "%PACKAGE_DIR%" >nul +move /Y "dist\*.whl" "%PACKAGE_DIR%" >nul + +REM Check if the move was successful +if %ERRORLEVEL% neq 0 ( + echo Failed to move the package to the destination folder. + endlocal + pause + goto :EOF +) + +echo Package created and moved to %PACKAGE_DIR% successfully. +endlocal +pause +goto :EOF diff --git a/installer_osgeo4w.bat b/installer_osgeo4w.bat new file mode 100644 index 0000000..1afab18 --- /dev/null +++ b/installer_osgeo4w.bat @@ -0,0 +1,41 @@ +@echo off +setlocal + +REM THIS DIDN'T SEEM TO WORK, LIKELY USER ERROR. + +REM Activate OSGeo4W environment +call "C:\OSGEO4W\bin\o4w_env.bat" + +REM Define the directory where the built package is stored +set "PACKAGE_DIR=%~dp0dist" +echo Package directory: %PACKAGE_DIR% + +REM Find the latest version of the package +for /f "delims=" %%i in ('dir /b /o-n "%PACKAGE_DIR%\pyromb-*.whl"') do ( + set "LATEST_PACKAGE=%%i" + goto found +) + +:found +if "%LATEST_PACKAGE%"=="" ( + echo No package found in the directory. + pause + endlocal + goto :EOF +) + +echo Installing or updating "%LATEST_PACKAGE%" + +REM Install or update the package using pip from the OSGeo4W environment +pip install --upgrade "%PACKAGE_DIR%\%LATEST_PACKAGE%" + +REM Check if the installation was successful +if %ERRORLEVEL% equ 0 ( + echo Installation completed successfully. +) else ( + echo Installation failed. 
Please check the path and try again. +) + +endlocal +pause +goto :EOF diff --git a/pyproject.toml b/pyproject.toml index bd67230..74012e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,26 +1,25 @@ [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["src/pyromb"] - -[tool.hatch.build] -only-packages = true -exclude = [ - ".conda", -] +requires = ["setuptools>=61.0.0", "wheel"] +build-backend = "setuptools.build_meta" [project] name = "pyromb" -version = "0.2.0" +version = "0.2.2" authors = [ - { name="Tom Norman", email="tom@normcosystems.com" } + { name = "Tom Norman", email = "tom@normcosystems.com" } ] description = "Runoff Model Builder (Pyromb) is a package used for building RORB and WBNM control files from catchment diagrams built from ESRI shapefiles. Its primary use is in the QGIS plugin Runoff Model: RORB and Runoff Model: WBNM" readme = "README.md" requires-python = ">=3.9" +dependencies = [ + "shapely", + "pyshp", + "matplotlib", + "numpy", + # Add any additional dependencies here +] + classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -29,4 +28,4 @@ classifiers = [ [project.urls] "Homepage" = "https://github.com/norman-tom/pyromb" -"Bug Tracker" = "https://github.com/norman-tom/pyromb/issues" \ No newline at end of file +"Bug Tracker" = "https://github.com/norman-tom/pyromb/issues" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b362768 --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup, find_packages +import os + +# Read the long description from README.md +current_dir = os.path.abspath(os.path.dirname(__file__)) +with open(os.path.join(current_dir, "README.md"), "r", encoding="utf-8") as fh: + long_description = fh.read() + +setup( + name="pyromb", + version="0.2.1", + packages=find_packages( + where="src", exclude=["*.tests", "*.tests.*", "tests.*", "tests"] + ), + 
package_dir={"": "src"}, + author="Tom Norman", + author_email="tom@normcosystems.com", + description="Runoff Model Builder (Pyromb) is a package used for building RORB and WBNM control files from catchment diagrams built from ESRI shapefiles. Its primary use is in the QGIS plugin Runoff Model: RORB and Runoff Model: WBNM", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/norman-tom/pyromb", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.12", + include_package_data=True, # Ensures inclusion of files specified in MANIFEST.in +) diff --git a/src/app.py b/src/app.py index db1ecb9..a172286 100644 --- a/src/app.py +++ b/src/app.py @@ -1,62 +1,196 @@ +# app.py import os import pyromb from plot_catchment import plot_catchment import shapefile as sf +from shapely.geometry import shape as shapely_shape +import logging +import json +from typing import Any +from pyromb.core.geometry.shapefile_validation import ( + validate_shapefile_fields, + validate_shapefile_geometries, + validate_confluences_out_field, +) + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") DIR = os.path.dirname(__file__) -REACH_PATH = os.path.join(DIR, '../data', 'reaches.shp') -BASIN_PATH = os.path.join(DIR, '../data', 'basins.shp') -CENTROID_PATH = os.path.join(DIR, '../data', 'centroids.shp') -CONFUL_PATH = os.path.join(DIR, '../data', 'confluences.shp') +REACH_PATH = os.path.join(DIR, "../data", "reaches.shp") +BASIN_PATH = os.path.join(DIR, "../data", "basins.shp") +CENTROID_PATH = os.path.join(DIR, "../data", "centroids.shp") +CONFUL_PATH = os.path.join(DIR, "../data", "confluences.shp") + +# Load expected fields from JSON file +with open(os.path.join(DIR, r"pyromb\resources", r"expected_fields.json"), "r") as f: + EXPECTED_FIELDS_JSON = json.load(f) + +# Convert JSON 
to the required dictionary format +EXPECTED_FIELDS = { + key: [(field["name"], field["type"]) for field in fields] + for key, fields in EXPECTED_FIELDS_JSON.items() +} + class SFVectorLayer(sf.Reader, pyromb.VectorLayer): """ Wrap the shapefile.Reader() with the necessary interface - to work with the builder. + to work with the builder. """ + def __init__(self, path) -> None: super().__init__(path) + # Extract field names, skipping the first DeletionFlag field + self.field_names = [field[0] for field in self.fields[1:]] + # Precompute Shapely geometries for all shapes + self.shapely_geometries = [ + shapely_shape(self.shape(i).__geo_interface__) for i in range(len(self)) + ] def geometry(self, i) -> list: return self.shape(i).points - + + def shapely_geometry(self, i): + """ + Return the Shapely geometry for the ith shape. + """ + return self.shapely_geometries[i] + def record(self, i) -> dict: - return super().record(i) - + """ + Return a dictionary mapping field names to their corresponding values. + """ + rec = super().record(i) + return dict(zip(self.field_names, rec)) + def __len__(self) -> int: return super().__len__() -def main(): - ### Config ### - plot = False # Set True of you want the catchment to be plotted - model = pyromb.RORB() # Select your hydrology model, either pyromb.RORB() or pyromb.WBNM() + +def main( + reach_path: str | None = None, + basin_path: str | None = None, + centroid_path: str | None = None, + confluence_path: str | None = None, + output_name: str | None = None, + plot: bool = False, + model: Any | None = None, +) -> None: + """ + Main function to build and process catchment data. + + Parameters + ---------- + reach_path : str + Path to the reaches shapefile. + basin_path : str + Path to the basins shapefile. + centroid_path : str + Path to the centroids shapefile. + confluence_path : str + Path to the confluences shapefile. + output_name : str + Name of the output file. + plot : bool + Whether to plot the catchment. 
+ model : pyromb.Model + The hydrology model to use. + """ + # Set default paths if not provided + reach_path = reach_path or REACH_PATH + basin_path = basin_path or BASIN_PATH + centroid_path = centroid_path or CENTROID_PATH + confluence_path = confluence_path or CONFUL_PATH + model = model or pyromb.RORB() + if isinstance(model, pyromb.RORB): + output_name = output_name or os.path.join(DIR, "../vector.catg") + else: + output_name = output_name or os.path.join(DIR, "../runfile.wbnm") + model = model or pyromb.RORB() ### Build Catchment Objects ### - # Vector layers - reach_vector = SFVectorLayer(REACH_PATH) - basin_vector = SFVectorLayer(BASIN_PATH) - centroid_vector = SFVectorLayer(CENTROID_PATH) - confluence_vector = SFVectorLayer(CONFUL_PATH) - # Create the builder. + # Vector layers + reach_vector = SFVectorLayer(reach_path) + basin_vector = SFVectorLayer(basin_path) + centroid_vector = SFVectorLayer(centroid_path) + confluence_vector = SFVectorLayer(confluence_path) + + # Validate shapefile fields + validation_reaches = validate_shapefile_fields( + reach_vector, "Reaches", EXPECTED_FIELDS["reaches"] + ) + validation_basins = validate_shapefile_fields( + basin_vector, "Basins", EXPECTED_FIELDS["basins"] + ) + validation_centroids = validate_shapefile_fields( + centroid_vector, "Centroids", EXPECTED_FIELDS["centroids"] + ) + validation_confluences = validate_shapefile_fields( + confluence_vector, "Confluences", EXPECTED_FIELDS["confluences"] + ) + + validate_confluences_out = validate_confluences_out_field( + confluence_vector, "Confluences" + ) + + # Validate shapefile geometries + validation_geometries_reaches = validate_shapefile_geometries( + reach_vector, "Reaches" + ) + validation_geometries_basins = validate_shapefile_geometries(basin_vector, "Basins") + validation_geometries_centroids = validate_shapefile_geometries( + centroid_vector, "Centroids" + ) + validation_geometries_confluences = validate_shapefile_geometries( + confluence_vector, "Confluences" 
+ ) + + # Decide whether to proceed based on validation + # Decide whether to proceed based on validation + if not all( + [ + validation_reaches, + validation_basins, + validation_centroids, + validation_confluences, + validate_confluences_out, + validation_geometries_reaches, + validation_geometries_basins, + validation_geometries_centroids, + validation_geometries_confluences, + ] + ): + logging.warning( + "One or more shapefiles failed validation. Proceeding with caution." + ) + else: + print("Shapefiles passed initial validation check.") + + # Create the builder. builder = pyromb.Builder() # Build each element as per the vector layer. tr = builder.reach(reach_vector) tc = builder.confluence(confluence_vector) tb = builder.basin(centroid_vector, basin_vector) - - ### Create the catchment ### + + ### Create the catchment ### catchment = pyromb.Catchment(tc, tb, tr) connected = catchment.connect() # Create the traveller and pass the catchment. traveller = pyromb.Traveller(catchment) - + ### Write ### # Control vector to file with a call to the Traveller's getVector method - with open(os.path.join(DIR, '../vector.catg' if isinstance(model, pyromb.RORB) else '../runfile.wbn'), 'w') as f: + output_path = output_name + with open(output_path, "w") as f: f.write(traveller.getVector(model)) - + print(f"Output written to {output_name}") + ### Plot the catchment ###. 
- if plot: plot_catchment(connected, tr, tc, tb) + if plot: + plot_catchment(connected, tr, tc, tb) + -if (__name__ == "__main__"): - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/src/app_testing.py b/src/app_testing.py new file mode 100644 index 0000000..dbeac23 --- /dev/null +++ b/src/app_testing.py @@ -0,0 +1,67 @@ +# app_testing.py +import os +import pyromb +from plot_catchment import plot_catchment +import shapefile as sf +from shapely.geometry import shape as shapely_shape +import logging + +from app import ( + main, + SFVectorLayer, +) # Import main function and SFVectorLayer from app.py + +# Configure logging (optional: configure in app.py instead) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + +# Define testing paths +TEST_DIR = r"Q:/qgis/" +TEST_REACH_PATH = os.path.join(TEST_DIR, "BC_reaches.shp") +TEST_BASIN_PATH = os.path.join(TEST_DIR, "BC_basins.shp") +TEST_CENTROID_PATH = os.path.join(TEST_DIR, "BC_centroids.shp") +TEST_CONFUL_PATH = os.path.join(TEST_DIR, "BC_confluences.shp") + +TEST_OUTPUT_PATH = r"Q:\qgis" +TEST_OUTPUT_NAME = r"testing_mod_python2.catg" +TEST_OUT = os.path.join(TEST_OUTPUT_PATH, TEST_OUTPUT_NAME) + + +def print_shapefile_fields(shp, name): + fields = shp.fields[1:] # skip DeletionFlag + field_names = [field[0] for field in fields] + print(f"{name} fields: {field_names}") + + +def test_main(): + ### Config ### + plot = False # Set to True if you want the catchment to be plotted + model = pyromb.RORB() + # Select your hydrology model, either pyromb.RORB() or pyromb.WBNM() + + ### Build Catchment Objects ### + # Vector layers with test paths + reach_vector = SFVectorLayer(TEST_REACH_PATH) + basin_vector = SFVectorLayer(TEST_BASIN_PATH) + centroid_vector = SFVectorLayer(TEST_CENTROID_PATH) + confluence_vector = SFVectorLayer(TEST_CONFUL_PATH) + + # Print field names (optional, for debugging) + print_shapefile_fields(reach_vector, "Reach") + 
print_shapefile_fields(basin_vector, "Basin") + print_shapefile_fields(centroid_vector, "Centroid") + print_shapefile_fields(confluence_vector, "Confluence") + + ### Call the main function with test paths and parameters ### + main( + reach_path=TEST_REACH_PATH, + basin_path=TEST_BASIN_PATH, + centroid_path=TEST_CENTROID_PATH, + confluence_path=TEST_CONFUL_PATH, + output_name=TEST_OUT, + plot=plot, + model=model, + ) + + +if __name__ == "__main__": + test_main() diff --git a/src/pyromb/core/attributes/reach.py b/src/pyromb/core/attributes/reach.py index 6e4ae98..65582d0 100644 --- a/src/pyromb/core/attributes/reach.py +++ b/src/pyromb/core/attributes/reach.py @@ -1,12 +1,14 @@ from ..geometry.line import Line from enum import Enum + class ReachType(Enum): NATURAL = 1 UNLINED = 2 LINED = 3 DROWNED = 4 + class Reach(Line): """A Reach object represents a reach as defined in hydrological models. @@ -20,45 +22,50 @@ class Reach(Line): The slope of the reach in m/m """ - def __init__(self, name: str = "", - vector: list = [], - type: ReachType = ReachType.NATURAL, - slope: float = 0.0): + def __init__( + self, + name: str = "", + vector: list = [], + type: ReachType = ReachType.NATURAL, + slope: float = 0.0, + ): super().__init__(vector) self._name: str = name self._type: ReachType = type self._slope: float = slope self._idx: int = 0 - + def __str__(self) -> str: - return "Name: {}\nLength: {}\nType: {}\nSlope: {}".format(self._name, round(self.length(), 3), self._type, self._slope) - + return "Name: {}\nLength: {}\nType: {}\nSlope: {}".format( + self._name, round(self.length(), 3), self._type, self._slope + ) + @property def name(self) -> str: return self._name - + @name.setter def name(self, name: str) -> None: self._name = name - + @property def type(self) -> ReachType: return self._type - + @type.setter def type(self, type: ReachType) -> None: self._type = type @property def slope(self) -> float: - return self._slope - + return self._slope + @slope.setter def 
slope(self, slope: float) -> None: self._slope = slope def getPoint(self, direction: str): - """ Returns either the upstream or downstream 'ds' point of the reach. + """Returns either the upstream or downstream 'ds' point of the reach. Parameters ---------- @@ -77,9 +84,18 @@ def getPoint(self, direction: str): If direction is not either 'us' or 'ds' """ - if direction == 'us': + if direction == "us": return self._vector[self._idx] - elif direction == 'ds': + elif direction == "ds": return self._vector[self._end - self._idx] else: - raise KeyError("Node direction not properly defines: \n") \ No newline at end of file + raise KeyError("Node direction not properly defines: \n") + + @property + def id(self) -> str: + """Alias for the 'name' attribute.""" + return self._name + + @id.setter + def id(self, value: str) -> None: + self._name = value diff --git a/src/pyromb/core/geometry/shapefile_validation.py b/src/pyromb/core/geometry/shapefile_validation.py new file mode 100644 index 0000000..f14b70b --- /dev/null +++ b/src/pyromb/core/geometry/shapefile_validation.py @@ -0,0 +1,148 @@ +from matplotlib.ft2font import SFNT +from shapely.geometry import shape +from shapely.validation import explain_validity +import shapefile as sf +import logging + + +def validate_shapefile_geometries(shp: sf.Reader, shapefile_name: str) -> bool: + """ + Validate the geometries of a shapefile. + + Parameters + ---------- + shp : shapefile.Reader + The shapefile reader object. + shapefile_name : str + The name of the shapefile (for logging purposes). + + Returns + ------- + bool + True if all geometries are valid, False otherwise. 
+ """ + validation_passed = True + for idx, shp_rec in enumerate(shp.shapes()): + geom = shape(shp_rec.__geo_interface__) + if not geom.is_valid: + validity_reason = explain_validity(geom) + logging.error( + f"Invalid geometry in {shapefile_name} at Shape ID {idx}: {validity_reason}" + ) + validation_passed = False + + if validation_passed: + logging.info(f"All geometries in {shapefile_name} are valid.") + + return validation_passed + + +import logging +import shapefile as sf +from typing import List, Tuple + + +def validate_shapefile_fields( + shp: sf.Reader, shapefile_name: str, expected_fields: List[Tuple[str, str]] +) -> bool: + """ + Validate the fields of a shapefile against expected field names and types. + Additionally, ensure that required fields contain valid data (not None or empty). + + Args: + shp (sf.Reader): Shapefile reader object. + shapefile_name (str): Name of the shapefile for logging purposes. + expected_fields (List[Tuple[str, str]]): List of tuples containing expected field names and their types. + + Returns: + bool: True if all expected fields are present with correct types and contain valid data, False otherwise. 
+ """ + TYPE_MAPPING = { + "C": "Character", + "N": "Numeric", + "F": "Float", + "L": "Logical", + "D": "Date", + "G": "General", + "M": "Memo", + } + + actual_fields = shp.fields[1:] # Skip DeletionFlag field + actual_field_names = [field[0] for field in actual_fields] + actual_field_types = [field[1] for field in actual_fields] + + logging.info(f"\nValidating fields for {shapefile_name}:") + for name, type_code in zip(actual_field_names, actual_field_types): + type_desc = TYPE_MAPPING.get(type_code, "Unknown") + logging.info(f" Field Name: {name}, Type: {type_code} ({type_desc})") + + validation_passed = True + + # Field Name and Type Validation + for exp_field, exp_type in expected_fields: + if exp_field not in actual_field_names: + logging.error(f"Missing expected field '{exp_field}' in {shapefile_name}.") + validation_passed = False + else: + idx = actual_field_names.index(exp_field) + act_type = actual_field_types[idx] + if act_type != exp_type: + type_desc = TYPE_MAPPING.get(act_type, "Unknown") + logging.error( + f"Type mismatch for field '{exp_field}' in {shapefile_name}: " + f"Expected '{exp_type}' ({TYPE_MAPPING.get(exp_type, 'Unknown')}), " + f"Got '{act_type}' ({type_desc})" + ) + validation_passed = False + + # Data Validation: Check for None or Empty Values in Required Fields + if validation_passed: + logging.info(f"Validating data integrity for fields in {shapefile_name}...") + for record_num, record in enumerate(shp.records(), start=1): + for exp_field, _ in expected_fields: + value = record[exp_field] + if value is None or (isinstance(value, str) and not value.strip()): + logging.error( + f"Empty or None value found in field '{exp_field}' " + f"for record {record_num} in {shapefile_name}." 
+ ) + validation_passed = False + if validation_passed: + logging.info(f"All required fields contain valid data in {shapefile_name}.") + + return validation_passed + + +def validate_confluences_out_field(shp: sf.Reader, shapefile_name: str) -> bool: + """ + Validate that the 'out' field in the Confluences shapefile has exactly one '1' and the rest '0'. + + Args: + shp (sf.Reader): Shapefile reader object. + shapefile_name (str): Name of the shapefile for logging purposes. + + Returns: + bool: True if the validation passes, False otherwise. + """ + out_values = [record["out"] for record in shp.records()] + + count_ones = out_values.count(1) + count_zeros = out_values.count(0) + total_records = len(out_values) + + if count_ones != 1: + logging.error( + f"The 'out' field in {shapefile_name} should have exactly one '1'. Found {count_ones}." + ) + return False + + if count_zeros != (total_records - 1): + logging.error( + f"The 'out' field in {shapefile_name} should have {total_records - 1} '0's. Found {count_zeros}." + ) + return False + + logging.info( + f"'out' field validation passed for {shapefile_name}: 1 '1' and {count_zeros} '0's." + ) + return True diff --git a/src/pyromb/core/gis/builder.py b/src/pyromb/core/gis/builder.py index 5d2e9cb..d7595ea 100644 --- a/src/pyromb/core/gis/builder.py +++ b/src/pyromb/core/gis/builder.py @@ -1,23 +1,26 @@ +# builder.py from ..attributes.basin import Basin from ..attributes.confluence import Confluence from ..attributes.reach import Reach from ..attributes.reach import ReachType -from ..geometry.line import pointVector -from ..geometry.point import Point from ..gis.vector_layer import VectorLayer -from ...math import geometry +import logging -class Builder(): +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +class Builder: """ Build the entities of the catchment. 
- The Builder is responsible for creating the entities (geometry, attributes) that - the catchment will be built from. Building must take place before the - catchment is connected and traversed. + The Builder is responsible for creating the entities (geometry, attributes) that + the catchment will be built from. Building must take place before the + catchment is connected and traversed. - The objects returned from the Builder are to be passed to the Catcment. + The objects returned from the Builder are to be passed to the Catchment. """ - + def reach(self, reach: VectorLayer) -> list: """Build the reach objects. @@ -29,14 +32,14 @@ def reach(self, reach: VectorLayer) -> list: Returns ------- list - A list of the reache objects. + A list of the reach objects. """ reaches = [] for i in range(len(reach)): s = reach.geometry(i) r = reach.record(i) - reaches.append(Reach(r['id'], s, ReachType(r['t']), r['s'])) + reaches.append(Reach(r["id"], s, ReachType(r["t"]), r["s"])) return reaches def basin(self, centroid: VectorLayer, basin: VectorLayer) -> list: @@ -55,23 +58,43 @@ def basin(self, centroid: VectorLayer, basin: VectorLayer) -> list: A list of the basin objects. """ basins = [] + # Precompute Shapely polygons for all basins + basin_geometries = [basin.shapely_geometry(j) for j in range(len(basin))] + for i in range(len(centroid)): - min = 0 - d = 999 - s = centroid.geometry(i) + centroid_geom = centroid.shapely_geometry(i) + centroid_point = centroid_geom.centroid # Shapely Point object + matching_basins = [] + + # Find all basins that contain the centroid point + for j, basin_geom in enumerate(basin_geometries): + if basin_geom.contains(centroid_point): + matching_basins.append(j) + + if not matching_basins: + logging.warning( + f"Centroid ID {centroid.record(i)['id']} at ({centroid_point.x}, {centroid_point.y}) " + f"is not contained within any basin polygon." 
+ ) + continue # Skip this centroid or handle as needed + + if len(matching_basins) > 1: + logging.error( + f"Centroid ID {centroid.record(i)['id']} at ({centroid_point.x}, {centroid_point.y}) " + f"is contained within multiple basins: {matching_basins}. " + f"Associating with the first matching basin." + ) + + # Associate with the first matching basin + associated_basin_idx = matching_basins[0] + associated_basin_geom = basin_geometries[associated_basin_idx] + # Area in the units of the shapefile's projection + a = associated_basin_geom.area r = centroid.record(i) - p = s[0] - for j in range(len(basin)): - b = basin.geometry(j) - v = b - c = geometry.polygon_centroid(pointVector(v)) - l = geometry.length([Point(p[0], p[1]), c]) - if l < d: - d = l - min = j - a = geometry.polygon_area(pointVector(basin.geometry(min))) - fi = r['fi'] - basins.append(Basin(r['id'], p[0], p[1], (a / 1E6), fi)) + fi = r["fi"] + p = centroid_geom.centroid.coords[0] + basins.append(Basin(r["id"], p[0], p[1], (a / 1e6), fi)) + return basins def confluence(self, confluence: VectorLayer) -> list: @@ -80,7 +103,7 @@ def confluence(self, confluence: VectorLayer) -> list: Parameters ---------- confluence : VectorLayer - The vector layer the confluences are on. + The vector layer the confluences are on. 
Returns ------- @@ -92,5 +115,5 @@ def confluence(self, confluence: VectorLayer) -> list: s = confluence.geometry(i) p = s[0] r = confluence.record(i) - confluences.append(Confluence(r['id'], p[0], p[1], bool(r['out']))) - return confluences \ No newline at end of file + confluences.append(Confluence(r["id"], p[0], p[1], bool(r["out"]))) + return confluences diff --git a/src/pyromb/resources/expected_fields.json b/src/pyromb/resources/expected_fields.json new file mode 100644 index 0000000..f5d4e20 --- /dev/null +++ b/src/pyromb/resources/expected_fields.json @@ -0,0 +1,37 @@ +{ + "reaches": [ + { + "name": "t", + "type": "N" + }, + { + "name": "s", + "type": "N" + }, + { + "name": "id", + "type": "C" + } + ], + "basins": [], + "centroids": [ + { + "name": "id", + "type": "C" + }, + { + "name": "fi", + "type": "N" + } + ], + "confluences": [ + { + "name": "id", + "type": "C" + }, + { + "name": "out", + "type": "N" + } + ] +} \ No newline at end of file From f9018b7763402e46ae9265f7688713e2768a275a Mon Sep 17 00:00:00 2001 From: Chain Frost Date: Sun, 20 Oct 2024 17:06:43 +0800 Subject: [PATCH 2/6] made the vertices snap to points --- data/reaches.dbf | Bin 1186 -> 1186 bytes data/reaches.shp | Bin 24476 -> 24476 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/data/reaches.dbf b/data/reaches.dbf index 93bc7f8be06884eeba212c7dc1866ae96d1f33ae..a3ef0e2d74b6aaedda950dda3a276d366380eba3 100644 GIT binary patch delta 13 UcmZ3)xrmd6xrR%4Bg;G%02odKP5=M^ delta 13 UcmZ3)xrmd6xr#-0Bg;G%02odKPyhe` diff --git a/data/reaches.shp b/data/reaches.shp index 9ba99d61d14f2e34886f8adb6c0b7970cb898832..d185859215d3f2cda614184b3fc4191d6d570157 100644 GIT binary patch delta 123 zcmbQUpK;E9#tj*w22Up?&Y1Tj*ilHfM|#>fImb8OB@dha40dE>U|`S&Vi17Ifz)oE vA!@+R3KBgu`K*jMgthsl47-vdNHFe|oFkBWU<);3B1jMl;x-G~cPjt@26ZC% delta 123 zcmbQUpK;E9#tj*w1_~QZYUce2cI=Sjcr*2zoa3ACl7~%y20Jn`FfeEXF$loqKx#M7 z5H( Date: Sun, 20 Oct 2024 17:14:48 +0800 Subject: [PATCH 3/6] used to output data 
objects for checking --- src/serialise.py | 135 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 src/serialise.py diff --git a/src/serialise.py b/src/serialise.py new file mode 100644 index 0000000..4e362e3 --- /dev/null +++ b/src/serialise.py @@ -0,0 +1,135 @@ +# src/serialise.py +import json +import csv +import logging +import pyromb +from pyromb.core.attributes.basin import Basin +from pyromb.core.attributes.confluence import Confluence +from pyromb.core.attributes.reach import Reach +from pyromb.core.catchment import Catchment + +# Set the default suffix +suffix_item = "_sample_new.json" + + +def serialize_to_json(data, filename: str, suffix: str = suffix_item) -> None: + """ + Serializes data to a JSON file. + + Parameters: + data: The data to serialize (e.g., dictionaries, lists). + filename: The target JSON file name without suffix. + suffix: The suffix to append to the filename. + """ + try: + with open(filename + suffix, "w") as f: + json.dump(data, f, indent=4) + logging.info(f"Serialized data to {filename + suffix}") + except Exception as e: + logging.error(f"Failed to serialize data to {filename + suffix}: {e}") + print(data) + + +def serialize_object(obj, filename: str, suffix: str = suffix_item) -> None: + """ + Serializes a custom object or list of objects to JSON. + Converts each object to a dictionary using helper functions. + + Parameters: + obj: The object or list of objects to serialize. + filename: The target JSON file name without suffix. + suffix: The suffix to append to the filename. 
+ """ + try: + if isinstance(obj, list): + # Determine the type of objects in the list and convert accordingly + if len(obj) == 0: + data = [] + else: + first_item = obj[0] + if isinstance(first_item, Reach): + data = [reach_to_dict(reach) for reach in obj] + elif isinstance(first_item, Confluence): + data = [confluence_to_dict(conf) for conf in obj] + elif isinstance(first_item, Basin): + data = [basin_to_dict(basin) for basin in obj] + else: + data = obj # Fallback: assume serializable + elif hasattr(obj, "to_dict"): + # If the object has a to_dict method, use it + data = obj.to_dict() + elif isinstance(obj, Catchment): + data = catchment_to_dict(obj) + else: + data = obj.__dict__ # Fallback: attempt to use __dict__ + + serialize_to_json(data, filename, suffix) + except Exception as e: + logging.error(f"Failed to serialize object to {filename + suffix}: {e}") + + +def serialize_matrix_to_csv(matrix, filename: str, suffix: str = suffix_item) -> None: + """ + Serializes a matrix (list of lists) to a CSV file. + + Parameters: + matrix: The matrix to serialize. + filename: The target CSV file name without suffix. + suffix: The suffix to append to the filename. + """ + try: + with open(filename + suffix, "w", newline="") as f: + writer = csv.writer(f) + writer.writerows(matrix) + logging.info(f"Serialized matrix to {filename + suffix}") + except Exception as e: + logging.error(f"Failed to serialize matrix to {filename + suffix}: {e}") + + +def reach_to_dict(reach): + """ + Converts a Reach object to a dictionary. + """ + return { + "name": reach.name, + "type": reach.reachType.name, # Convert Enum to string + "slope": reach.slope, + "vector": [{"x": point._x, "y": point._y} for point in reach._vector], + # Accessing the internal _vector attribute + } + + +def confluence_to_dict(confluence): + """ + Converts a Confluence object to a dictionary. 
+ """ + return { + "name": confluence.name, + "isOut": confluence.isOut, # Correct attribute + # Add other relevant attributes here if necessary + } + + +def basin_to_dict(basin): + """ + Converts a Basin object to a dictionary. + """ + return { + "name": basin.name, + "area": basin.area, + "fi": basin.fi, + # Add other relevant attributes here if necessary + } + + +def catchment_to_dict(catchment): + """ + Converts a Catchment object to a dictionary. + """ + return { + "edges": [reach.name for reach in catchment._edges], # Assuming _edges is a list of Reach objects + "vertices": [conf.name for conf in catchment._vertices], # Assuming _vertices is a list of Confluence objects + "incidenceMatrixDS": catchment._incidenceMatrixDS.tolist(), + "incidenceMatrixUS": catchment._incidenceMatrixUS.tolist(), + # Add other relevant attributes here if necessary + } From fe58b88951ca825f1d7f8a4ab9171579da886869 Mon Sep 17 00:00:00 2001 From: Chain Frost Date: Sun, 20 Oct 2024 19:46:04 +0800 Subject: [PATCH 4/6] working with app.py, now to test via QGIS --- .gitignore | 11 + pyproject.toml | 7 +- setup.py | 8 +- src/app.py | 280 ++++++------ src/app_testing.py | 61 ++- src/plot_catchment.py | 17 +- src/pyromb/core/attributes/node.py | 18 +- src/pyromb/core/attributes/reach.py | 57 ++- src/pyromb/core/catchment.py | 407 ++++++++++++----- src/pyromb/core/geometry/line.py | 188 +++++--- src/pyromb/core/geometry/point.py | 66 ++- src/pyromb/core/geometry/polygon.py | 57 ++- .../core/geometry/shapefile_validation.py | 395 +++++++++++++---- src/pyromb/core/gis/builder.py | 363 +++++++++++++--- src/pyromb/core/gis/vector_layer.py | 69 ++- src/pyromb/core/traveller.py | 94 ++-- src/pyromb/math/geometry.py | 410 +++++++++++++++--- src/pyromb/model/rorb.py | 303 +++++++------ src/pyromb/resources/expected_fields.json | 14 +- src/serialise.py | 3 + src/sf_vector_layer.py | 135 ++++++ vector.catg | 281 ------------ 22 files changed, 2142 insertions(+), 1102 deletions(-) create mode 100644 
src/sf_vector_layer.py delete mode 100644 vector.catg diff --git a/.gitignore b/.gitignore index b6e4761..55dbe07 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,14 @@ dmypy.json # Pyre type checker .pyre/ + +# others +gpt-input.txt +src/app_testing.py +.vscode/settings.json +test_gdal.py + +#started grabbing some unit tests from o1-mini +tests/ + +src/app_testing.py diff --git a/pyproject.toml b/pyproject.toml index 74012e1..a39aa8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pyromb" -version = "0.2.2" +version = "0.2.1" authors = [ { name = "Tom Norman", email = "tom@normcosystems.com" } ] @@ -13,10 +13,7 @@ readme = "README.md" requires-python = ">=3.9" dependencies = [ - "shapely", - "pyshp", - "matplotlib", - "numpy", + "gdal", # Add any additional dependencies here ] diff --git a/setup.py b/setup.py index b362768..9f87093 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import setup, find_packages # type:ignore import os # Read the long description from README.md @@ -9,9 +9,7 @@ setup( name="pyromb", version="0.2.1", - packages=find_packages( - where="src", exclude=["*.tests", "*.tests.*", "tests.*", "tests"] - ), + packages=find_packages(where="src", exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), package_dir={"": "src"}, author="Tom Norman", author_email="tom@normcosystems.com", @@ -24,6 +22,6 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires=">=3.12", + python_requires=">=3.9", include_package_data=True, # Ensures inclusion of files specified in MANIFEST.in ) diff --git a/src/app.py b/src/app.py index a172286..1942db5 100644 --- a/src/app.py +++ b/src/app.py @@ -1,195 +1,157 @@ -# app.py +# src/app.py import os import pyromb from plot_catchment import plot_catchment -import shapefile as sf -from shapely.geometry import shape as shapely_shape 
+from sf_vector_layer import SFVectorLayer import logging -import json -from typing import Any -from pyromb.core.geometry.shapefile_validation import ( - validate_shapefile_fields, - validate_shapefile_geometries, - validate_confluences_out_field, -) +from typing import Any, Optional +from serialise import serialize_object, serialize_matrix_to_csv # Configure logging -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +# logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +# Define shapefile paths DIR = os.path.dirname(__file__) REACH_PATH = os.path.join(DIR, "../data", "reaches.shp") BASIN_PATH = os.path.join(DIR, "../data", "basins.shp") CENTROID_PATH = os.path.join(DIR, "../data", "centroids.shp") CONFUL_PATH = os.path.join(DIR, "../data", "confluences.shp") -# Load expected fields from JSON file -with open(os.path.join(DIR, r"pyromb\resources", r"expected_fields.json"), "r") as f: - EXPECTED_FIELDS_JSON = json.load(f) - -# Convert JSON to the required dictionary format -EXPECTED_FIELDS = { - key: [(field["name"], field["type"]) for field in fields] - for key, fields in EXPECTED_FIELDS_JSON.items() -} - - -class SFVectorLayer(sf.Reader, pyromb.VectorLayer): - """ - Wrap the shapefile.Reader() with the necessary interface - to work with the builder. - """ - - def __init__(self, path) -> None: - super().__init__(path) - # Extract field names, skipping the first DeletionFlag field - self.field_names = [field[0] for field in self.fields[1:]] - # Precompute Shapely geometries for all shapes - self.shapely_geometries = [ - shapely_shape(self.shape(i).__geo_interface__) for i in range(len(self)) - ] - - def geometry(self, i) -> list: - return self.shape(i).points - - def shapely_geometry(self, i): - """ - Return the Shapely geometry for the ith shape. - """ - return self.shapely_geometries[i] - - def record(self, i) -> dict: - """ - Return a dictionary mapping field names to their corresponding values. 
- """ - rec = super().record(i) - return dict(zip(self.field_names, rec)) - - def __len__(self) -> int: - return super().__len__() - def main( - reach_path: str | None = None, - basin_path: str | None = None, - centroid_path: str | None = None, - confluence_path: str | None = None, - output_name: str | None = None, + output_name: Optional[str] = None, + reach_path: str = REACH_PATH, + basin_path: str = BASIN_PATH, + centroid_path: str = CENTROID_PATH, + confluence_path: str = CONFUL_PATH, + model: Any = pyromb.RORB(), plot: bool = False, - model: Any | None = None, + serialise_for_testing: bool = False, ) -> None: """ Main function to build and process catchment data. Parameters ---------- - reach_path : str - Path to the reaches shapefile. - basin_path : str - Path to the basins shapefile. - centroid_path : str - Path to the centroids shapefile. - confluence_path : str - Path to the confluences shapefile. - output_name : str - Name of the output file. - plot : bool - Whether to plot the catchment. - model : pyromb.Model - The hydrology model to use. + output_name : Optional[str] + Name of the output file. If not provided, a default name is assigned based on the model. + reach_path : str, optional + Path to the reaches shapefile. Defaults to REACH_PATH. + basin_path : str, optional + Path to the basins shapefile. Defaults to BASIN_PATH. + centroid_path : str, optional + Path to the centroids shapefile. Defaults to CENTROID_PATH. + confluence_path : str, optional + Path to the confluences shapefile. Defaults to CONFUL_PATH. + plot : bool, optional + Whether to plot the catchment. Defaults to False. + model : + The hydrology model to use. Defaults to pyromb.RORB(). 
""" - # Set default paths if not provided - reach_path = reach_path or REACH_PATH - basin_path = basin_path or BASIN_PATH - centroid_path = centroid_path or CENTROID_PATH - confluence_path = confluence_path or CONFUL_PATH - model = model or pyromb.RORB() + + # Assign default output name based on the model type if isinstance(model, pyromb.RORB): - output_name = output_name or os.path.join(DIR, "../vector.catg") + default_output = os.path.join(DIR, "../vector.catg") else: - output_name = output_name or os.path.join(DIR, "../runfile.wbnm") - model = model or pyromb.RORB() + default_output = os.path.join(DIR, "../runfile.wbnm") + output_name = output_name or default_output ### Build Catchment Objects ### - # Vector layers - reach_vector = SFVectorLayer(reach_path) - basin_vector = SFVectorLayer(basin_path) - centroid_vector = SFVectorLayer(centroid_path) - confluence_vector = SFVectorLayer(confluence_path) - - # Validate shapefile fields - validation_reaches = validate_shapefile_fields( - reach_vector, "Reaches", EXPECTED_FIELDS["reaches"] - ) - validation_basins = validate_shapefile_fields( - basin_vector, "Basins", EXPECTED_FIELDS["basins"] - ) - validation_centroids = validate_shapefile_fields( - centroid_vector, "Centroids", EXPECTED_FIELDS["centroids"] - ) - validation_confluences = validate_shapefile_fields( - confluence_vector, "Confluences", EXPECTED_FIELDS["confluences"] - ) - - validate_confluences_out = validate_confluences_out_field( - confluence_vector, "Confluences" - ) - - # Validate shapefile geometries - validation_geometries_reaches = validate_shapefile_geometries( - reach_vector, "Reaches" - ) - validation_geometries_basins = validate_shapefile_geometries(basin_vector, "Basins") - validation_geometries_centroids = validate_shapefile_geometries( - centroid_vector, "Centroids" - ) - validation_geometries_confluences = validate_shapefile_geometries( - confluence_vector, "Confluences" - ) - - # Decide whether to proceed based on validation - # Decide 
whether to proceed based on validation - if not all( - [ - validation_reaches, - validation_basins, - validation_centroids, - validation_confluences, - validate_confluences_out, - validation_geometries_reaches, - validation_geometries_basins, - validation_geometries_centroids, - validation_geometries_confluences, - ] - ): - logging.warning( - "One or more shapefiles failed validation. Proceeding with caution." - ) - else: - print("Shapefiles passed initial validation check.") - - # Create the builder. - builder = pyromb.Builder() - # Build each element as per the vector layer. - tr = builder.reach(reach_vector) - tc = builder.confluence(confluence_vector) - tb = builder.basin(centroid_vector, basin_vector) + try: + # Initialize vector layers using OGR + reach_vector = SFVectorLayer(reach_path) + basin_vector = SFVectorLayer(basin_path) + centroid_vector = SFVectorLayer(centroid_path) + confluence_vector = SFVectorLayer(confluence_path) + logging.info("Successfully loaded all shapefile layers.") + except (FileNotFoundError, IndexError, ValueError) as e: + logging.error(f"Failed to load shapefile layers: {e}") + return + except Exception as e: + logging.error(f"An unexpected error occurred while loading shapefiles: {e}") + return + + # Create the builder + try: + builder = pyromb.Builder() + logging.info("Builder initialized successfully.") + except Exception as e: + logging.error(f"Failed to initialize Builder: {e}") + return + + # Build catchment components + try: + tr = builder.reach(reach_vector) + logging.info("Built reaches.") + tc = builder.confluence(confluence_vector) + logging.info("Built confluences.") + tb = builder.basin(centroid_vector, basin_vector) + logging.info("Built basins.") + logging.info("Catchment components built successfully.") + + # Serialize components if flag is set + if serialise_for_testing: + serialize_object(tr, "reach") + # Serializes to "reach_new.json" or "reach_old.json" based on serialise.py copy + serialize_object(tc, 
"confluence") # Serializes to "confluence_new.json" or "confluence_old.json" + serialize_object(tb, "basin") # Serializes to "basin_new.json" or "basin_old.json" + logging.info("Serialized reaches, confluences, and basins.") + except Exception as e: + logging.error(f"Failed to build catchment components: {e}") + return ### Create the catchment ### - catchment = pyromb.Catchment(tc, tb, tr) - connected = catchment.connect() + try: + catchment = pyromb.Catchment(tc, tb, tr) + connected = catchment.connect() + logging.info("Catchment created and connected successfully.") + + # Serialize catchment if flag is set + if serialise_for_testing: + serialize_object(catchment, "catchment") + # Serializes to "catchment_new.json" or "catchment_old.json" + logging.info("Serialized catchment.") + except Exception as e: + logging.error(f"Failed to create and connect catchment: {e}") + return + + # Connect the catchment and serialize connection matrices + try: + ds_matrix, us_matrix = catchment.connect() + if serialise_for_testing: + serialize_matrix_to_csv(ds_matrix.tolist(), "ds_matrix") + # Serializes to "ds_matrix_new.json" or "ds_matrix_old.json" + serialize_matrix_to_csv(us_matrix.tolist(), "us_matrix") + # Serializes to "us_matrix_new.json" or "us_matrix_old.json" + logging.info("Serialized connection matrices.") + except Exception as e: + logging.error(f"Failed to connect catchment or serialize connection matrices: {e}") + return + # Create the traveller and pass the catchment. - traveller = pyromb.Traveller(catchment) + try: + traveller = pyromb.Traveller(catchment) + logging.info("Traveller created successfully.") + except Exception as e: + logging.error(f"Failed to create Traveller: {e}") + return ### Write ### - # Control vector to file with a call to the Traveller's getVector method - output_path = output_name - with open(output_path, "w") as f: - f.write(traveller.getVector(model)) - print(f"Output written to {output_name}") - - ### Plot the catchment ###. 
+ try: + with open(output_name, "w") as f: + f.write(traveller.getVector(model)) + logging.info(f"Output written to {output_name}") + except Exception as e: + logging.error(f"Failed to write output: {e}") + return + + ### Plot the catchment ### if plot: - plot_catchment(connected, tr, tc, tb) + try: + plot_catchment(connected, tr, tc, tb) + logging.info("Catchment plotted successfully.") + except Exception as e: + logging.error(f"Failed to plot catchment: {e}") if __name__ == "__main__": diff --git a/src/app_testing.py b/src/app_testing.py index dbeac23..b5f5c5a 100644 --- a/src/app_testing.py +++ b/src/app_testing.py @@ -2,39 +2,51 @@ import os import pyromb from plot_catchment import plot_catchment -import shapefile as sf -from shapely.geometry import shape as shapely_shape import logging +from sf_vector_layer import SFVectorLayer +from app import main -from app import ( - main, - SFVectorLayer, -) # Import main function and SFVectorLayer from app.py +# Configure logging +logging.basicConfig( + level=logging.DEBUG, # Capture all levels of logs (DEBUG and above) + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], + # handlers=[logging.FileHandler("debug.log"), logging.StreamHandler()], # Log to a file # Also log to the console +) -# Configure logging (optional: configure in app.py instead) -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") -# Define testing paths -TEST_DIR = r"Q:/qgis/" -TEST_REACH_PATH = os.path.join(TEST_DIR, "BC_reaches.shp") -TEST_BASIN_PATH = os.path.join(TEST_DIR, "BC_basins.shp") -TEST_CENTROID_PATH = os.path.join(TEST_DIR, "BC_centroids.shp") -TEST_CONFUL_PATH = os.path.join(TEST_DIR, "BC_confluences.shp") +# Define shapefile paths +DIR = os.path.dirname(__file__) +TEST_REACH_PATH = os.path.join(DIR, "../data", "reaches.shp") +TEST_BASIN_PATH = os.path.join(DIR, "../data", "basins.shp") +TEST_CENTROID_PATH = os.path.join(DIR, "../data", "centroids.shp") +TEST_CONFUL_PATH = 
os.path.join(DIR, "../data", "confluences.shp") -TEST_OUTPUT_PATH = r"Q:\qgis" -TEST_OUTPUT_NAME = r"testing_mod_python2.catg" +PARENT_DIR = os.path.dirname(DIR) # This gets the parent folder of the current directory +TEST_OUTPUT_PATH = os.path.join(PARENT_DIR, r"./") +TEST_OUTPUT_NAME = r"testing_ogr_2.catg" TEST_OUT = os.path.join(TEST_OUTPUT_PATH, TEST_OUTPUT_NAME) -def print_shapefile_fields(shp, name): - fields = shp.fields[1:] # skip DeletionFlag - field_names = [field[0] for field in fields] +def print_vector_layer_fields(vector_layer, name): + """ + Print the field names of a vector layer using OGR. + + Parameters: + - vector_layer: An instance of SFVectorLayer. + - name: A string indicating the name of the layer. + """ + # Access the layer definition + layer_defn = vector_layer.layer.GetLayerDefn() + # Get the field names + field_names = [layer_defn.GetFieldDefn(i).GetName() for i in range(layer_defn.GetFieldCount())] print(f"{name} fields: {field_names}") def test_main(): ### Config ### - plot = False # Set to True if you want the catchment to be plotted + plot = True # Set to True if you want the catchment to be plotted + serialise_for_testing = False model = pyromb.RORB() # Select your hydrology model, either pyromb.RORB() or pyromb.WBNM() @@ -46,10 +58,10 @@ def test_main(): confluence_vector = SFVectorLayer(TEST_CONFUL_PATH) # Print field names (optional, for debugging) - print_shapefile_fields(reach_vector, "Reach") - print_shapefile_fields(basin_vector, "Basin") - print_shapefile_fields(centroid_vector, "Centroid") - print_shapefile_fields(confluence_vector, "Confluence") + print_vector_layer_fields(reach_vector, "Reach") + print_vector_layer_fields(basin_vector, "Basin") + print_vector_layer_fields(centroid_vector, "Centroid") + print_vector_layer_fields(confluence_vector, "Confluence") ### Call the main function with test paths and parameters ### main( @@ -60,6 +72,7 @@ def test_main(): output_name=TEST_OUT, plot=plot, model=model, + 
serialise_for_testing=serialise_for_testing, ) diff --git a/src/plot_catchment.py b/src/plot_catchment.py index 61859df..0f59c45 100644 --- a/src/plot_catchment.py +++ b/src/plot_catchment.py @@ -1,5 +1,10 @@ import matplotlib.pyplot as plt import numpy as np +import logging + +# Configure logging (adjust as needed or configure in a higher-level module) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + def plot_catchment(connected, tr, tc, tb): cx = [] @@ -36,7 +41,7 @@ def plot_catchment(connected, tr, tc, tb): nodeNames.append("{}[{}]".format(n.name, i)) for i, b in enumerate(tb): nodeNames.append("{}[{}]".format(b.name, i + len(cname))) - + fig, ax = plt.subplots() im = ax.imshow(connected[0]) ax.set_xticks(np.arange(0, len(reachNames))) @@ -45,8 +50,8 @@ def plot_catchment(connected, tr, tc, tb): ax.set_yticklabels(nodeNames) for i in range(len(nodeNames)): for j in range(len(reachNames)): - text = ax.text(j, i, connected[0][i, j], ha='center', va='center', color='w') - ax.set_title('Catchment Incidence Matrix (DS)') + text = ax.text(j, i, connected[0][i, j], ha="center", va="center", color="w") + ax.set_title("Catchment Incidence Matrix (DS)") fig.tight_layout() fig, ax = plt.subplots() @@ -57,7 +62,7 @@ def plot_catchment(connected, tr, tc, tb): ax.set_yticklabels(nodeNames) for i in range(len(nodeNames)): for j in range(len(reachNames)): - text = ax.text(j, i, connected[1][i, j], ha='center', va='center', color='w') - ax.set_title('Catchment Incidence Matrix (US)') + text = ax.text(j, i, connected[1][i, j], ha="center", va="center", color="w") + ax.set_title("Catchment Incidence Matrix (US)") fig.tight_layout() - plt.show() \ No newline at end of file + plt.show() diff --git a/src/pyromb/core/attributes/node.py b/src/pyromb/core/attributes/node.py index 8d1659c..cdfb050 100644 --- a/src/pyromb/core/attributes/node.py +++ b/src/pyromb/core/attributes/node.py @@ -1,9 +1,11 @@ +# src/pyromb/core/attributes/node.py from 
..geometry.point import Point + class Node(Point): - """Node in the catchment tree. - - Encapsulates attributes of point like features in the catchment such as + """Node in the catchment tree. + + Encapsulates attributes of point-like features in the catchment such as basins and confluences. Attributes @@ -18,8 +20,12 @@ def __init__(self, name: str = "", x: float = 0.0, y: float = 0.0) -> None: @property def name(self) -> str: - return self._name - + return self._name + @name.setter - def name(self, name: str): + def name(self, name: str) -> None: self._name = name + + def coordinates(self) -> tuple[float, float]: + """Return the (x, y) coordinates of the Node.""" + return (self.x, self.y) diff --git a/src/pyromb/core/attributes/reach.py b/src/pyromb/core/attributes/reach.py index 65582d0..2a79d15 100644 --- a/src/pyromb/core/attributes/reach.py +++ b/src/pyromb/core/attributes/reach.py @@ -1,5 +1,9 @@ +# src/pyromb/core/attributes/reach.py +from typing import Optional, Union, cast from ..geometry.line import Line from enum import Enum +from .node import Node +from ..geometry.point import Point class ReachType(Enum): @@ -16,7 +20,7 @@ class Reach(Line): ---------- name : str The name of the reach, should be unique - type : ReachType + reachType : ReachType The type of reach as specified by the hydrological model. 
slope : float The slope of the reach in m/m @@ -25,20 +29,17 @@ class Reach(Line): def __init__( self, name: str = "", - vector: list = [], - type: ReachType = ReachType.NATURAL, + vector: Optional[list[Node]] = None, + reachType: ReachType = ReachType.NATURAL, slope: float = 0.0, ): - super().__init__(vector) + super().__init__(cast(list[Point], vector) if vector is not None else []) self._name: str = name - self._type: ReachType = type + self._reachType: ReachType = reachType self._slope: float = slope - self._idx: int = 0 def __str__(self) -> str: - return "Name: {}\nLength: {}\nType: {}\nSlope: {}".format( - self._name, round(self.length(), 3), self._type, self._slope - ) + return f"Name: {self._name}\nLength: {round(self.length, 3)}\nType: {self.reachType}\nSlope: {self._slope}" @property def name(self) -> str: @@ -49,12 +50,12 @@ def name(self, name: str) -> None: self._name = name @property - def type(self) -> ReachType: - return self._type + def reachType(self) -> ReachType: + return self._reachType - @type.setter - def type(self, type: ReachType) -> None: - self._type = type + @reachType.setter + def reachType(self, reachType: ReachType) -> None: + self._reachType = reachType @property def slope(self) -> float: @@ -64,18 +65,18 @@ def slope(self) -> float: def slope(self, slope: float) -> None: self._slope = slope - def getPoint(self, direction: str): - """Returns either the upstream or downstream 'ds' point of the reach. + def getPoint(self, direction: str) -> Node: + """Returns either the upstream or downstream 'us'/'ds' point of the reach. Parameters ---------- direction : str - 'us' - for upstream point. \n + 'us' - for upstream point. 
'ds' - for downstream point Returns ------- - Point + Node The US or DS point Raises @@ -83,19 +84,17 @@ def getPoint(self, direction: str): KeyError If direction is not either 'us' or 'ds' """ - if direction == "us": - return self._vector[self._idx] + return cast(Node, self._vector[0]) # Assuming the first point is upstream elif direction == "ds": - return self._vector[self._end - self._idx] + return cast(Node, self._vector[-1]) # Assuming the last point is downstream else: - raise KeyError("Node direction not properly defines: \n") + raise KeyError("Node direction not properly defined: expected 'us' or 'ds'.") - @property - def id(self) -> str: - """Alias for the 'name' attribute.""" - return self._name + def getStart(self) -> Node: + """Get the upstream node of the reach.""" + return self.getPoint("us") - @id.setter - def id(self, value: str) -> None: - self._name = value + def getEnd(self) -> Node: + """Get the downstream node of the reach.""" + return self.getPoint("ds") diff --git a/src/pyromb/core/catchment.py b/src/pyromb/core/catchment.py index 2788056..f155e83 100644 --- a/src/pyromb/core/catchment.py +++ b/src/pyromb/core/catchment.py @@ -1,114 +1,327 @@ +# src/pyromb/core/catchment.py + +from typing import Optional, Union +import numpy as np +import logging +from collections import deque from .attributes.basin import Basin from .attributes.confluence import Confluence from .attributes.node import Node from .attributes.reach import Reach -from ..math import geometry -import numpy as np +from ..math.geometry import length +import math + +# Configure logging (adjust as needed or configure in a higher-level module) +# logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + class Catchment: - """The Catchment is a tree of attributes which describes how water - flows through the model and the entities which act upon it. 
+ """ + The Catchment is a tree of attributes that describes how water + flows through the model and the entities which act upon it. Parameters ---------- - confluences : list[Confluence] - The Confluences in the catchment - basins : list[Basin] - The Basins in the catchment - reaches: list[Reach] - The Reaches in the catchment + confluences : Optional[List[Confluence]] + The Confluences in the catchment. + basins : Optional[List[Basin]] + The Basins in the catchment. + reaches : Optional[List[Reach]] + The Reaches in the catchment. """ - def __init__(self, confluences: list = [], basins: list = [], reaches: list = []) -> None: - self._edges: list[Reach] = reaches - self._vertices: list[Node] = confluences + basins - self._incidenceMatrixDS: list = [] - self._incidenceMatrixUS: list = [] - self._out = 0 - self._endSentinel = -1 + def __init__( + self, + confluences: Optional[list[Confluence]] = None, + basins: Optional[list[Basin]] = None, + reaches: Optional[list[Reach]] = None, + ) -> None: + self._edges: list[Reach] = reaches if reaches is not None else [] + self._vertices: list[Node] = [] + if confluences: + self._vertices.extend(confluences) + if basins: + self._vertices.extend(basins) + self._incidenceMatrixDS: np.ndarray = np.array([]) + self._incidenceMatrixUS: np.ndarray = np.array([]) + self._out_node_index: int = -1 + self._endSentinel: int = -1 + + def _initialize_connection_matrix(self, num_vertices: int, num_edges: int) -> np.ndarray: + """ + Initialize and populate the connection matrix based on vertex and edge distances. + + Parameters + ---------- + num_vertices : int + Number of vertices in the catchment. + num_edges : int + Number of edges (reaches) in the catchment. + + Returns + ------- + np.ndarray + The populated connection matrix. 
+ """ + connection_matrix = np.zeros((num_vertices, num_edges), dtype=int) + + for edge_idx, edge in enumerate(self._edges): + start_node = edge.getStart() + end_node = edge.getEnd() + + # Find the closest start and end vertices + closest_start_idx = self._find_closest_vertex(start_node) + closest_end_idx = self._find_closest_vertex(end_node) + + if closest_start_idx == -1 or closest_end_idx == -1: + logging.warning(f"Edge {edge_idx} has no valid start or end vertex.") + continue + + # Update the connection matrix + connection_matrix[closest_start_idx, edge_idx] = 1 # Upstream connection + connection_matrix[closest_end_idx, edge_idx] = 2 # Downstream connection + + return connection_matrix + + def _build_incidence_matrices( + self, + connection_matrix: np.ndarray, + incidence_matrix_ds: np.ndarray, + incidence_matrix_us: np.ndarray, + num_vertices: int, + num_edges: int, + ) -> None: + """ + Build the downstream and upstream incidence matrices using BFS traversal. + + Parameters + ---------- + connection_matrix : np.ndarray + The connection matrix indicating upstream and downstream connections. + incidence_matrix_ds : np.ndarray + The downstream incidence matrix to populate. + incidence_matrix_us : np.ndarray + The upstream incidence matrix to populate. + num_vertices : int + Number of vertices in the catchment. + num_edges : int + Number of edges (reaches) in the catchment. 
+ """ + # Initialize separate BFS queues and color matrices for US and DS + queue_us = deque() + queue_ds = deque() + + # Start BFS for Upstream connections + queue_us.append((self._out_node_index, -1)) # (vertex index, edge index) + colour_us = np.zeros((num_vertices, num_edges), dtype=int) + + while queue_us: + current_vertex, incoming_edge = queue_us.popleft() + + for edge_idx in range(num_edges): + connection_type = connection_matrix[current_vertex, edge_idx] + + if connection_type != 1: + continue # Only process upstream connections + + if colour_us[current_vertex, edge_idx] != 0: + continue # Edge already processed + + # Mark as visited + colour_us[current_vertex, edge_idx] = 1 + + # Determine downstream vertex based on connection type + edge: Reach = self._edges[edge_idx] + downstream_node: Node = edge.getEnd() + + downstream_vertex_coords = downstream_node.coordinates() + downstream_vertex_idx = self._find_vertex_by_coordinates(downstream_vertex_coords) + + if downstream_vertex_idx == -1: + logging.warning( + f"Downstream vertex for edge {edge_idx} not found.\n" + f"Edge Details: Start Node ID {edge.getStart().name}, " + f"End Node ID {edge.getEnd().name}, Edge ID {edge.name}.\n" + f"Downstream Node Coordinates: {downstream_node.coordinates()}" + ) + continue + + # Update incidence matrices + incidence_matrix_us[current_vertex, edge_idx] = downstream_vertex_idx + + # Enqueue the downstream vertex for further processing + queue_us.append((downstream_vertex_idx, edge_idx)) + + # Start BFS for Downstream connections + queue_ds.append((self._out_node_index, -1)) # (vertex index, edge index) + colour_ds = np.zeros((num_vertices, num_edges), dtype=int) + + while queue_ds: + current_vertex, incoming_edge = queue_ds.popleft() + + for edge_idx in range(num_edges): + connection_type = connection_matrix[current_vertex, edge_idx] + + if connection_type != 2: + continue # Only process downstream connections + + if colour_ds[current_vertex, edge_idx] != 0: + continue # 
Edge already processed + + # Mark as visited + colour_ds[current_vertex, edge_idx] = 1 + + # Determine upstream vertex based on connection type + edge: Reach = self._edges[edge_idx] + upstream_node: Node = edge.getStart() - def connect(self) -> tuple: - """Connect the individual attributes to create the catchment. + upstream_vertex_coords = upstream_node.coordinates() + upstream_vertex_idx = self._find_vertex_by_coordinates(upstream_vertex_coords) + + if upstream_vertex_idx == -1: + logging.warning( + f"Upstream vertex for edge {edge_idx} not found.\n" + f"Edge Details: Start Node ID {edge.getStart().name}, " + f"End Node ID {edge.getEnd().name}, Edge ID {edge.name}.\n" + f"Upstream Node Coordinates: {upstream_node.coordinates()}" + ) + continue + + # Update incidence matrices + incidence_matrix_ds[current_vertex, edge_idx] = upstream_vertex_idx + + # Enqueue the upstream vertex for further processing + queue_ds.append((upstream_vertex_idx, edge_idx)) + + def _find_vertex_by_coordinates(self, coords: Union[list[float], tuple[float, float]], tol=1e-1) -> int: + """ + Find the index of a vertex based on its coordinates with a tolerance. + + Parameters + ---------- + coords : Union[List[float], Tuple[float, float]] + The (x, y) coordinates to match. + tol : float + Tolerance for floating-point comparison. + + Returns + ------- + int + The index of the matching vertex. Returns -1 if not found. + """ + for idx, vertex in enumerate(self._vertices): + vertex_coords = vertex.coordinates() + distance = math.hypot(vertex_coords[0] - coords[0], vertex_coords[1] - coords[1]) + if distance <= tol: + return idx + return -1 + + def connect(self) -> tuple[np.ndarray, np.ndarray]: + """ + Connect the individual attributes to create the catchment. Returns ------- - tuple - (downstream, upstream) incidence matricies of the catchment tree. + tuple[np.ndarray, np.ndarray] + (downstream, upstream) incidence matrices of the catchment tree. 
""" - - connectionMatrix = np.zeros((len(self._vertices), len(self._edges)), dtype=int) - for i, edge in enumerate(self._edges): - s = edge.getStart() - e = edge.getEnd() - minStart = 999 - minEnd = 999 - closestStart = 0 - closestEnd = 0 - for j, vert in enumerate(self._vertices): - tempStart = geometry.length([vert, s]) - tempEnd = geometry.length([vert, e]) - if tempStart < minStart: - closestStart = j - minStart = tempStart - if tempEnd < minEnd: - closestEnd = j - minEnd = tempEnd - connectionMatrix[closestStart][i] = 1 - connectionMatrix[closestEnd][i] = 2 - - + num_vertices = len(self._vertices) + num_edges = len(self._edges) + + # Initialize the connection matrix + connection_matrix = np.zeros((num_vertices, num_edges), dtype=int) + + for edge_idx, edge in enumerate(self._edges): + start_node = edge.getStart() + end_node = edge.getEnd() + + # Find the closest start and end vertex indices + closest_start_idx = self._find_closest_vertex(start_node) + closest_end_idx = self._find_closest_vertex(end_node) + + if closest_start_idx == -1 or closest_end_idx == -1: + logging.warning(f"Edge {edge_idx} has no valid start or end vertex.") + continue + + # Populate connection matrix + connection_matrix[closest_start_idx, edge_idx] = 1 # Upstream connection + connection_matrix[closest_end_idx, edge_idx] = 2 # Downstream connection + # Find the 'out' node - # Used to determine the starting point of breath first search - # And subsequently the direction of flow - for k, conf in enumerate(self._vertices): - if isinstance(conf, Confluence): - if conf.isOut: - self._out = k - break - - - # Determine incidence matrix relating reaches to nodes and map downstream direction between elements - # Matrix I (m * m - 1) - # m = nodes - # n = reaches - # value of m n = the index of the downstream node - # Think about I as relating upstream nodes (m) to downstream nodes (m n) through reach (n) - # (m n) of -1 indicates no downstream node for relationship m n - newIncidenceDS = 
np.zeros((len(self._vertices), len(self._edges)), dtype=int) - newIncidenceDS.fill(self._endSentinel) - newIncidenceUS = newIncidenceDS.copy() - queue = [] - colour = np.zeros((len(self._vertices), len(self._edges))) - i = self._out - j = 0 - queue.append((i, j)) - while(len(queue) != 0): - #Move in the n direction - u = queue.pop() - idxi = u[0] - j = u[1] - for k in range(len(connectionMatrix[u[0]])): - idxj = j % len(connectionMatrix[idxi]) - if connectionMatrix[idxi][idxj] > 0: - if colour[idxi][idxj] == 0: - colour[idxi][idxj] = 1 - u = (idxi, idxj) - queue.append(u) - j += 1 - - #Move in the m direction - i = u[0] - idxj = u[1] - for l in range(len(connectionMatrix)): - idxi = i % len(connectionMatrix) - if connectionMatrix[idxi][idxj] > 0: - if colour[idxi][idxj] == 0: - colour[idxi][idxj] = 1 - queue.append((idxi, idxj)) - newIncidenceUS[u[0]][u[1]] = idxi - newIncidenceDS[idxi][idxj] = u[0] - i += 1 - self._incidenceMatrixDS = newIncidenceDS.copy() - self._incidenceMatrixUS = newIncidenceUS.copy() - - return (self._incidenceMatrixDS, self._incidenceMatrixUS) \ No newline at end of file + self._find_out_node() + + # Initialize incidence matrices with sentinel values + incidence_matrix_ds = np.full((num_vertices, num_edges), self._endSentinel, dtype=int) + incidence_matrix_us = np.full((num_vertices, num_edges), self._endSentinel, dtype=int) + + # Populate Upstream Incidence Matrix + for edge_idx, edge in enumerate(self._edges): + start_idx = self._find_closest_vertex(edge.getStart()) + end_idx = self._find_closest_vertex(edge.getEnd()) + + if start_idx != -1 and end_idx != -1: + # Corrected: Downstream Matrix: [start_vertex][edge] = end_vertex + incidence_matrix_ds[start_idx, edge_idx] = end_idx + + # Corrected: Upstream Matrix: [end_vertex][edge] = start_vertex + incidence_matrix_us[end_idx, edge_idx] = start_idx + + # Debugging: Log the matrices + logging.debug("Connection Matrix:") + logging.debug(connection_matrix) + logging.debug("Incidence Matrix 
US:") + logging.debug(incidence_matrix_us) + logging.debug("Incidence Matrix DS:") + logging.debug(incidence_matrix_ds) + + self._incidenceMatrixDS = incidence_matrix_ds.copy() + self._incidenceMatrixUS = incidence_matrix_us.copy() + + return (self._incidenceMatrixDS, self._incidenceMatrixUS) + + def _find_closest_vertex(self, node: Node, tol: float = 1e-6) -> int: + """ + Find the index of the closest vertex to a given node based on Cartesian distance. + + Parameters + ---------- + node : Node + The node to find the closest vertex for. + tol : float + Tolerance for floating-point comparison. + + Returns + ------- + int + The index of the closest vertex. Returns -1 if no vertex is found. + """ + min_distance = float("inf") + closest_vertex_idx = -1 + + node_coords = node.coordinates() + + for vertex_idx, vertex in enumerate(self._vertices): + vertex_coords = vertex.coordinates() + distance = math.hypot(vertex_coords[0] - node_coords[0], vertex_coords[1] - node_coords[1]) + + if distance < min_distance: + min_distance = distance + closest_vertex_idx = vertex_idx + + if min_distance > tol: + logging.warning(f"No vertex within tolerance for node at {node_coords}. Closest distance: {min_distance}") + return -1 + + return closest_vertex_idx + + def _find_out_node(self) -> None: + """ + Identify and set the 'out' node in the catchment. 
+ """ + for idx, vertex in enumerate(self._vertices): + if isinstance(vertex, Confluence) and vertex.isOut: + self._out_node_index = idx + logging.info(f"'Out' node found at index {idx}.") + return + logging.error("No 'out' node found in confluences.") + raise ValueError("No 'out' node found in confluences.") diff --git a/src/pyromb/core/geometry/line.py b/src/pyromb/core/geometry/line.py index d66dc9e..780eb5f 100644 --- a/src/pyromb/core/geometry/line.py +++ b/src/pyromb/core/geometry/line.py @@ -1,82 +1,90 @@ -from .point import Point +# src/pyromb/core/geometry/line.py +from typing import Optional, Iterator +from osgeo import ogr from ...math import geometry +from .point import Point + + +class GeometryError(Exception): + """Custom exception for geometry-related errors.""" + + pass + -class Line(): +class Line: """An object representing a line shape type. - + + A line is a sequence of points that defines a path. + Attributes ---------- length : float - - Parametersup + The cartesian length of the line. + + Parameters ---------- - vector : list[Points] - The points that make the line. + vector : Optional[List[Point]] + The points that make up the line. 
""" - def __init__(self, vector:list = []): - super().__init__() - self._vector = pointVector(vector) - self._end = len(self._vector) - 1 - self._length = geometry.length(self.toVector()) + def __init__(self, vector: Optional[list[Point]] = None) -> None: + if vector is None: + vector = [] + self._vector: list[Point] = self.pointVector(vector) + self._length: float = self.calculate_length() - def __iter__(self): - self.n = 0 + def __iter__(self) -> Iterator[Point]: + self._current = 0 return self - + def __next__(self) -> Point: - if self.n <= self._end: - point = self._vector[self.n] - self.n += 1 + if self._current < len(self._vector): + point = self._vector[self._current] + self._current += 1 return point else: raise StopIteration - - def __len__(self): - return self._end - - def __getitem__(self, i): - return self._vector[i] - - def __setitem__(self, i, v:Point): - self._vector[i] = v - - def append(self, point:Point): - """Add an additional point to the line. - Append adds the point to the head of the geometry. + def __len__(self) -> int: + """Return the number of points in the line.""" + return len(self._vector) + + def __getitem__(self, index: int) -> Point: + return self._vector[index] + + def __setitem__(self, index: int, value: Point): + if not isinstance(value, Point): + raise TypeError("Only Point instances can be assigned.") + self._vector[index] = value + self._length = self.calculate_length() + + def append(self, point: Point): + """Add an additional point to the line. Parameters ---------- point : Point The point to add to the line. """ - + if not isinstance(point, Point): + raise TypeError("Only Point instances can be appended.") self._vector.append(point) - self._end += 1 - self._length = geometry.length(self.toVector()) - - def length(self) -> float: - """The cartisian length of the line. 
- - Returns - ------- - float - The length - """ + self._length = self.calculate_length() + @property + def length(self) -> float: + """float: The cartesian length of the line.""" return self._length - def toVector(self) -> list: - """Convert the line into a vector of points. + def toVector(self) -> list[Point]: + """Convert the line into a list of points. Returns ------- - list - A list of points. + List[Point] + A list of Point objects. """ - - return self._vector + return self._vector.copy() def getStart(self) -> Point: """Get the starting point of the line. @@ -85,36 +93,78 @@ def getStart(self) -> Point: ------- Point The start point. - """ + Raises + ------ + GeometryError + If the line is empty. + """ + if not self._vector: + raise GeometryError("Line is empty. No start point available.") return self._vector[0] - + def getEnd(self) -> Point: """Get the end point of the line. Returns ------- Point - The end point + The end point. + + Raises + ------ + GeometryError + If the line is empty. """ - - return self._vector[self._end] + if not self._vector: + raise GeometryError("Line is empty. No end point available.") + return self._vector[-1] -def pointVector(vector:list) -> list: - """Convert a list of x,y co-ordinates into a list of Points + def calculate_length(self) -> float: + """Calculate the cartesian length of the line. - Parameters - ---------- - vector : list - A list of (x,y) co-ordinate tuple as floats. + Returns + ------- + float + The length of the line. - Returns - ------- - list - A list of (x,y) co-odinate tuple as points. - """ + Raises + ------ + GeometryError + If the line cannot be converted to an OGR LineString. 
+ """ + if not self._vector: + return 0.0 + coords = [point.coordinates() for point in self._vector] + try: + ogr_line = geometry.create_line_string(coords) + length = geometry.calculate_length(ogr_line) + return length + except geometry.GeometryError as ge: + raise GeometryError(f"Failed to calculate length: {ge}") + + @staticmethod + def pointVector(vector: list[Point]) -> list[Point]: + """Convert a list of Points ensuring all elements are Point instances. + + Parameters + ---------- + vector : List[Point] + A list of Point objects. - points = [] - for t in vector: - points.append(Point(t[0], t[1])) - return points \ No newline at end of file + Returns + ------- + List[Point] + A validated list of Point objects. + + Raises + ------ + TypeError + If any element in the vector is not a Point instance. + """ + if not isinstance(vector, list): + raise TypeError("Vector must be a list of Point instances.") + for idx, item in enumerate(vector): + if not isinstance(item, Point): + raise TypeError(f"Item at index {idx} is not a Point instance.") + return vector.copy() diff --git a/src/pyromb/core/geometry/point.py b/src/pyromb/core/geometry/point.py index 2273bfe..e3367b0 100644 --- a/src/pyromb/core/geometry/point.py +++ b/src/pyromb/core/geometry/point.py @@ -1,28 +1,64 @@ +# src/pyromb/core/geometry/point.py + + class Point: """An object representing a point shape type. - + Parameters ---------- x : float - The x co-ordinate + The x coordinate. y : float - The y co-ordinate + The y coordinate. """ def __init__(self, x: float = 0.0, y: float = 0.0): - self._x = x - self._y = y - - def __str__(self): - return "[{}, {}]".format(self._x, self._y) - - def coordinates(self) -> tuple: - """The co-ordinates of the point. + if not isinstance(x, (int, float)) or not isinstance(y, (int, float)): + raise ValueError(f"Invalid coordinates: x={x}, y={y}. 
Both must be numbers.") + self._x = float(x) + self._y = float(y) + + @property + def x(self) -> float: + """float: The x coordinate.""" + return self._x + + @property + def y(self) -> float: + """float: The y coordinate.""" + return self._y + + def coordinates(self) -> tuple[float, float]: + """Get the coordinates of the point. + + Returns + ------- + Tuple[float, float] + The (x, y) coordinates. + """ + return (self._x, self._y) + + def __str__(self) -> str: + """Return a string representation of the point.""" + return f"[{self._x}, {self._y}]" + + def __repr__(self) -> str: + """Return an unambiguous string representation of the point.""" + return f"Point(x={self._x}, y={self._y})" + + def __eq__(self, other) -> bool: + """Check if two points are equal based on their coordinates. + + Parameters + ---------- + other : Point + The other point to compare. Returns ------- - tuple - (x,y) co-ordinates. + bool + True if both points have the same coordinates, False otherwise. """ - - return (self._x, self._y) + if not isinstance(other, Point): + return NotImplemented + return self._x == other._x and self._y == other._y diff --git a/src/pyromb/core/geometry/polygon.py b/src/pyromb/core/geometry/polygon.py index 04e6927..45b6466 100644 --- a/src/pyromb/core/geometry/polygon.py +++ b/src/pyromb/core/geometry/polygon.py @@ -1,35 +1,60 @@ +# src/pyromb/core/geometry/polygon.py +from typing import List, Optional +from osgeo import ogr +from ...math import geometry from .line import Line -from math import geometry from .point import Point + class Polygon(Line): - """An object representing a polyline shape type. + """An object representing a polygon shape type. - A polyline is a closed line which makes an area. + A polygon is a closed shape that defines an area. Attributes ---------- area : float - The cartesian area of the polygon + The cartesian area of the polygon. centroid : Point - The centroid of the polygon - + The centroid of the polygon. 
+ Parameters ---------- - vector : list[Points] - The points which form the polygon + vector : Optional[List[Point]] + The points which form the polygon. """ - - def __init__(self, vector:list = []): + + def __init__(self, vector: Optional[list[Point]] = None) -> None: + if vector is None: + vector = [] super().__init__(vector) - self.append(self[0]) - self._area = geometry.polygon_area(self.toVector()) - self._centroid = geometry.polygon_centroid(self.toVector()) - + + if not self: + raise ValueError("Vector cannot be empty to form a Polygon.") + + # Ensure the polygon is closed by appending the first point at the end if necessary + if self[0] != self[-1]: + self.append(self[0]) + + # Convert the list of Points to a list of (x, y) tuples + coords = [point.coordinates() for point in self.toVector()] + + try: + # Create an OGR Polygon geometry + ogr_polygon = geometry.create_polygon(coords) + + # Calculate area and centroid using the updated geometry module + self._area = geometry.calculate_area(ogr_polygon) + self._centroid = geometry.calculate_centroid(ogr_polygon) + except geometry.GeometryError as ge: + raise ValueError(f"Failed to initialize Polygon: {ge}") + @property def area(self) -> float: + """float: The cartesian area of the polygon.""" return self._area - + @property def centroid(self) -> Point: - return self._centroid \ No newline at end of file + """Point: The centroid of the polygon.""" + return self._centroid diff --git a/src/pyromb/core/geometry/shapefile_validation.py b/src/pyromb/core/geometry/shapefile_validation.py index f14b70b..281e117 100644 --- a/src/pyromb/core/geometry/shapefile_validation.py +++ b/src/pyromb/core/geometry/shapefile_validation.py @@ -1,148 +1,365 @@ -from matplotlib.ft2font import SFNT -from shapely.geometry import shape -from shapely.validation import explain_validity -import shapefile as sf +# src\pyromb\core\geometry\shapefile_validation.py import logging +from osgeo import ogr # type: ignore +from ..gis.vector_layer 
import VectorLayer +from ...math import geometry +from osgeo import ogr -def validate_shapefile_geometries(shp: sf.Reader, shapefile_name: str) -> bool: + +def validate_shapefile_geometries(vector_layer: VectorLayer, layer_type: str) -> bool: """ - Validate the geometries of a shapefile. + Validate the geometries of a vector layer based on the expected geometry type for the layer. Parameters ---------- - shp : shapefile.Reader - The shapefile reader object. - shapefile_name : str - The name of the shapefile (for logging purposes). + vector_layer : VectorLayer + The vector layer to validate. + layer_type : str + The type of layer (e.g., 'reaches', 'basins', 'confluences', 'centroids'). Returns ------- bool - True if all geometries are valid, False otherwise. + True if all geometries are valid and match the expected type, False otherwise. """ validation_passed = True - for idx, shp_rec in enumerate(shp.shapes()): - geom = shape(shp_rec.__geo_interface__) - if not geom.is_valid: - validity_reason = explain_validity(geom) + layer_name = layer_type.capitalize() + + # Define expected geometry types for each layer type + expected_geometry_types = { + "reaches": ogr.wkbLineString, + "basins": ogr.wkbPolygon, + "confluences": ogr.wkbPoint, + "centroids": ogr.wkbPoint, + } + + # Get the expected geometry type for the layer + expected_geom_type = expected_geometry_types.get(layer_type.lower()) + + if expected_geom_type is None: + logging.error(f"No expected geometry type defined for layer '{layer_type}'.") + return False + + logging.info( + f"Validating geometries for {layer_name} layer. 
Expected geometry type: {ogr.GeometryTypeToName(expected_geom_type)}" + ) + + for i in range(len(vector_layer)): + # Get the geometry from the vector layer + ogr_geom = vector_layer.get_ogr_geometry(i) + if ogr_geom is None or ogr_geom.IsEmpty(): + logging.error(f"Feature at index {i} has an empty geometry.") + validation_passed = False + continue + + # Check if geometry is valid + if not ogr_geom.IsValid(): + logging.error(f"Feature at index {i} has an invalid geometry.") + validation_passed = False + continue + + # Get the actual geometry type + actual_geom_type = ogr_geom.GetGeometryType() + + # Handle multi-geometries + if actual_geom_type in ( + ogr.wkbMultiPoint, + ogr.wkbMultiLineString, + ogr.wkbMultiPolygon, + ogr.wkbMultiPoint25D, + ogr.wkbMultiLineString25D, + ogr.wkbMultiPolygon25D, + ): + num_geoms = ogr_geom.GetGeometryCount() + if num_geoms == 1: + logging.warning(f"Feature at index {i} is a multi-geometry with a single geometry. Proceeding.") + # Extract the single geometry + ogr_geom = ogr_geom.GetGeometryRef(0) + actual_geom_type = ogr_geom.GetGeometryType() + else: + logging.error( + f"Feature at index {i} is a multi-geometry with {num_geoms} geometries. Expected only one." + ) + validation_passed = False + continue + + # Check if actual geometry type is acceptable, ignoring 2D vs 3D + if not are_geometry_types_equivalent(expected_geom_type, actual_geom_type): + expected_geom_name = ogr.GeometryTypeToName(expected_geom_type) + actual_geom_name = ogr.GeometryTypeToName(actual_geom_type) logging.error( - f"Invalid geometry in {shapefile_name} at Shape ID {idx}: {validity_reason}" + f"Feature at index {i} has geometry type '{actual_geom_name}', " f"but expected '{expected_geom_name}'." 
) validation_passed = False + continue if validation_passed: - logging.info(f"All geometries in {shapefile_name} are valid.") + logging.info(f"All geometries in {layer_name} layer are valid and match the expected type.") + else: + logging.error(f"Geometry validation failed for {layer_name} layer.") return validation_passed -import logging -import shapefile as sf -from typing import List, Tuple +def are_geometry_types_equivalent(expected_geom_type, actual_geom_type): + """ + Determine if the actual geometry type is acceptable for the expected geometry type, + ignoring the 2D vs 3D distinction. + """ + # Remove the 3D flag if present + expected_geom_type_2d = expected_geom_type & (~ogr.wkb25DBit) + actual_geom_type_2d = actual_geom_type & (~ogr.wkb25DBit) + + # Map multi-geometry types to their single counterparts + multi_to_single = { + ogr.wkbMultiPoint: ogr.wkbPoint, + ogr.wkbMultiLineString: ogr.wkbLineString, + ogr.wkbMultiPolygon: ogr.wkbPolygon, + } + + # If actual is a multi-geometry, map it to the single geometry type + if actual_geom_type_2d in multi_to_single: + # For multi-geometries, check if expected type matches the single type + actual_geom_type_2d = multi_to_single[actual_geom_type_2d] + + # Similarly, if expected is a multi-geometry, map it to single type + if expected_geom_type_2d in multi_to_single: + expected_geom_type_2d = multi_to_single[expected_geom_type_2d] + + return expected_geom_type_2d == actual_geom_type_2d def validate_shapefile_fields( - shp: sf.Reader, shapefile_name: str, expected_fields: List[Tuple[str, str]] + vector_layer: VectorLayer, shapefile_name: str, expected_fields: list[dict[str, str]] ) -> bool: """ - Validate the fields of a shapefile against expected field names and types. - Additionally, ensure that required fields contain valid data (not None or empty). + Validate the fields of a vector layer against expected field names and types. - Args: - shp (sf.Reader): Shapefile reader object. 
- shapefile_name (str): Name of the shapefile for logging purposes. - expected_fields (List[Tuple[str, str]]): List of tuples containing expected field names and their types. + Parameters + ---------- + vector_layer : VectorLayer + The vector layer to validate. + shapefile_name : str + The name of the shapefile for logging. + expected_fields : list[dict[str, str]] + A list of expected field definitions with name and type. - Returns: - bool: True if all expected fields are present with correct types and contain valid data, False otherwise. + Returns + ------- + bool + True if all expected fields are present and valid, False otherwise. """ - TYPE_MAPPING = { - "C": "Character", - "N": "Numeric", - "F": "Float", - "L": "Logical", - "D": "Date", - "G": "General", - "M": "Memo", - } + validation_passed = True - actual_fields = shp.fields[1:] # Skip DeletionFlag field - actual_field_names = [field[0] for field in actual_fields] - actual_field_types = [field[1] for field in actual_fields] + logging.info(f"Starting field validation for {shapefile_name} layer.") - logging.info(f"\nValidating fields for {shapefile_name}:") - for name, type_code in zip(actual_field_names, actual_field_types): - type_desc = TYPE_MAPPING.get(type_code, "Unknown") - logging.info(f" Field Name: {name}, Type: {type_code} ({type_desc})") + # Get the actual fields and their types from the vector layer + actual_fields = vector_layer.get_fields() # Call the method to get field names and types - validation_passed = True + # Normalize actual field names to lowercase for case-insensitive comparison + actual_field_names = [field[0].lower() for field in actual_fields] + actual_field_types = {field[0].lower(): field[1] for field in actual_fields} + + logging.info(f"Expected fields for {shapefile_name}:") + for exp_field in expected_fields: + logging.info(f" Field Name: {exp_field['name']}, Expected Type: {exp_field['type']}") - # Field Name and Type Validation - for exp_field, exp_type in expected_fields: - 
if exp_field not in actual_field_names: - logging.error(f"Missing expected field '{exp_field}' in {shapefile_name}.") + logging.info(f"Actual fields in {shapefile_name}:") + for name, type_code in actual_fields: + type_name = ogr.GetFieldTypeName(type_code) + logging.info(f" Field Name: {name}, Type: {type_name}") + + # Define acceptable type mappings + type_mappings = { + "Integer": ["Integer", "Integer64"], + "Integer64": ["Integer", "Integer64"], + "Real": ["Real"], + "String": ["String"], + } + + # Validate field names and types + for exp_field in expected_fields: + exp_name = exp_field["name"].lower() + exp_type = exp_field["type"] + + if exp_name not in actual_field_names: + logging.error(f"Missing expected field '{exp_field['name']}' in {shapefile_name}.") validation_passed = False else: - idx = actual_field_names.index(exp_field) - act_type = actual_field_types[idx] - if act_type != exp_type: - type_desc = TYPE_MAPPING.get(act_type, "Unknown") + actual_type_code = actual_field_types[exp_name] + actual_type_name = ogr.GetFieldTypeName(actual_type_code) + # Get acceptable types for expected type + acceptable_types = type_mappings.get(exp_type, [exp_type]) + if actual_type_name not in acceptable_types: logging.error( - f"Type mismatch for field '{exp_field}' in {shapefile_name}: " - f"Expected '{exp_type}' ({TYPE_MAPPING.get(exp_type, 'Unknown')}), " - f"Got '{act_type}' ({type_desc})" + f"Type mismatch for field '{exp_field['name']}' in {shapefile_name}: " + f"Expected '{exp_type}', Got '{actual_type_name}'" ) validation_passed = False + else: + logging.info(f"Field '{exp_field['name']}' in {shapefile_name} matches expected type '{exp_type}'.") - # Data Validation: Check for None or Empty Values in Required Fields + # Data integrity validation (e.g., missing or empty values) if validation_passed: - logging.info(f"Validating data integrity for fields in {shapefile_name}...") - for record_num, record in enumerate(shp.records(), start=1): - for exp_field, _ in 
expected_fields: - value = record[exp_field] - if value is None or (isinstance(value, str) and not value.strip()): - logging.error( - f"Empty or None value found in field '{exp_field}' " - f"for record {record_num} in {shapefile_name}." - ) - validation_passed = False + logging.info(f"Field names and types validated successfully for {shapefile_name}.") + validation_passed = validate_field_data_integrity(vector_layer, shapefile_name, expected_fields) if validation_passed: - logging.info(f"All required fields contain valid data in {shapefile_name}.") + logging.info(f"Field data integrity validated successfully for {shapefile_name}.") + else: + logging.error(f"Field data integrity validation failed for {shapefile_name}.") + else: + logging.error(f"Field names or types validation failed for {shapefile_name}.") return validation_passed -def validate_confluences_out_field(shp: sf.Reader, shapefile_name: str) -> bool: +def validate_field_names_and_types( + shapefile_name: str, expected_fields: list[dict[str, str]], actual_fields: list[tuple[str, int]] +) -> bool: """ - Validate that the 'out' field in the Confluences shapefile has exactly one '1' and the rest '0'. + Validate the field names and types against the expected fields. - Args: - shp (sf.Reader): Shapefile reader object. - shapefile_name (str): Name of the shapefile for logging purposes. + Parameters + ---------- + shapefile_name : str + The name of the shapefile for logging purposes. + expected_fields : list of dict + List of expected field names and types. + actual_fields : list of tuple + List of (field_name, field_type_code) from the vector layer. - Returns: - bool: True if the validation passes, False otherwise. + Returns + ------- + bool + True if all expected fields are present with correct types, False otherwise. 
""" - out_values = [record["out"] for record in shp.records()] + actual_field_names = [field[0] for field in actual_fields] + actual_field_types = {field[0]: field[1] for field in actual_fields} + + validation_passed = True + + for exp_field in expected_fields: + exp_name = exp_field["name"] + exp_type = exp_field["type"] + + if exp_name not in actual_field_names: + logging.error(f"Missing expected field '{exp_name}' in {shapefile_name}.") + validation_passed = False + else: + actual_type_code = actual_field_types[exp_name] + actual_type_name = ogr.GetFieldTypeName(actual_type_code) + + if actual_type_name != exp_type: + logging.error( + f"Type mismatch for field '{exp_name}' in {shapefile_name}: " + f"Expected '{exp_type}', Got '{actual_type_name}'" + ) + validation_passed = False + + return validation_passed + + +def validate_field_data_integrity( + vector_layer: VectorLayer, shapefile_name: str, expected_fields: list[dict[str, str]] +) -> bool: + """ + Validate the data integrity of required fields in the vector layer. + + Parameters + ---------- + vector_layer : VectorLayer + The vector layer to validate. + shapefile_name : str + The name of the shapefile for logging purposes. + expected_fields : list of dict + List of expected field names and types. + + Returns + ------- + bool + True if all required fields contain valid data, False otherwise. 
+ """ + validation_passed = True + logging.info(f"Validating data integrity for fields in {shapefile_name}...") + + for i in range(len(vector_layer)): + fid = i # Assume vector layer uses index as feature ID + for exp_field in expected_fields: + exp_name = exp_field["name"] + value = vector_layer.get_attributes(i).get(exp_name) + + if value is None: + logging.error(f"None value found in field '{exp_name}' for Feature ID {fid} in {shapefile_name}.") + validation_passed = False + elif isinstance(value, str) and not value.strip(): + logging.error(f"Empty string found in field '{exp_name}' for Feature ID {fid} in {shapefile_name}.") + validation_passed = False + + return validation_passed + + +def validate_confluences_out_field(vector_layer: VectorLayer, shapefile_name: str) -> bool: + """ + Validate that the 'out' field in the confluences vector layer has exactly one '1' and the rest '0'. + + Parameters + ---------- + vector_layer : VectorLayer + The vector layer to validate. + shapefile_name : str + The name of the shapefile for logging purposes. + + Returns + ------- + bool: True if the validation passes, False otherwise. + """ + validation_passed = True + + out_values = [vector_layer.get_attributes(i).get("out") for i in range(len(vector_layer))] count_ones = out_values.count(1) count_zeros = out_values.count(0) total_records = len(out_values) if count_ones != 1: - logging.error( - f"The 'out' field in {shapefile_name} should have exactly one '1'. Found {count_ones}." - ) - return False + logging.error(f"The 'out' field in {shapefile_name} should have exactly one '1'. Found {count_ones}.") + validation_passed = False if count_zeros != (total_records - 1): - logging.error( - f"The 'out' field in {shapefile_name} should have {total_records - 1} '0's. Found {count_zeros}." - ) - return False + logging.error(f"The 'out' field in {shapefile_name} should have {total_records - 1} '0's. 
Found {count_zeros}.") + validation_passed = False - logging.info( - f"'out' field validation passed for {shapefile_name}: 1 '1' and {count_zeros} '0's." - ) - return True + if validation_passed: + logging.info(f"'out' field validation passed for {shapefile_name}: 1 '1' and {count_zeros} '0's.") + else: + logging.error(f"'out' field validation failed for {shapefile_name}.") + + return validation_passed + + +def load_points(vector_layer: VectorLayer) -> list[ogr.Geometry]: + """ + Load point geometries from a vector layer. + + Parameters + ---------- + vector_layer : VectorLayer + The vector layer to extract point geometries from. + + Returns + ------- + List[ogr.Geometry] + List of point geometries. + """ + points = [] + for i in range(len(vector_layer)): + geom_coords = vector_layer.get_geometry(i) + if len(geom_coords) == 1: # Assuming point geometries have one coordinate + x, y = geom_coords[0] + point = ogr.Geometry(ogr.wkbPoint) + point.AddPoint(x, y) + points.append(point) + else: + logging.error(f"Invalid point geometry at index {i} in vector layer.") + return points diff --git a/src/pyromb/core/gis/builder.py b/src/pyromb/core/gis/builder.py index d7595ea..9fa2a6c 100644 --- a/src/pyromb/core/gis/builder.py +++ b/src/pyromb/core/gis/builder.py @@ -1,13 +1,37 @@ -# builder.py +# src/pyromb/core/geometry/builder.py +import os +import json +import logging +from typing import Optional +from osgeo import ogr # type:ignore +import importlib.resources + from ..attributes.basin import Basin from ..attributes.confluence import Confluence from ..attributes.reach import Reach from ..attributes.reach import ReachType from ..gis.vector_layer import VectorLayer -import logging +from ..attributes.node import Node + +# Import validation functions +from pyromb.core.geometry.shapefile_validation import ( + validate_shapefile_fields, + validate_shapefile_geometries, + validate_confluences_out_field, +) + +# Import geometry functions from math.geometry +from 
pyromb.math.geometry import ( + create_line_string, + create_polygon, + create_point, + is_geometry_empty, + calculate_area, + contains, +) # Configure logging -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +# logging.basicConfig(level=logging.DEBUG, format="%(levelname)s: %(message)s") class Builder: @@ -21,99 +45,310 @@ class Builder: The objects returned from the Builder are to be passed to the Catchment. """ - def reach(self, reach: VectorLayer) -> list: - """Build the reach objects. + def __init__(self, expected_fields_path: Optional[str] = None): + """ + Initialize the Builder instance by loading expected fields from a JSON file. + """ + # Define the directory where expected_fields.json is located + if expected_fields_path is None: + DIR = os.path.dirname(os.path.abspath(__file__)) + expected_fields_json_path = os.path.join(DIR, "resources", "expected_fields.json") + with importlib.resources.open_text("pyromb.resources", "expected_fields.json") as f: + EXPECTED_FIELDS_JSON = json.load(f) + else: + expected_fields_json_path = expected_fields_path + + # Load expected fields from JSON file + try: + with importlib.resources.open_text("pyromb.resources", "expected_fields.json") as f: + EXPECTED_FIELDS_JSON = json.load(f) + logging.info(f"Loaded expected fields from {expected_fields_json_path}.") + except FileNotFoundError: + logging.critical(f"Expected fields JSON file not found at {expected_fields_json_path}.") + raise FileNotFoundError(f"Expected fields JSON file not found at {expected_fields_json_path}.") + except json.JSONDecodeError as e: + logging.critical(f"Error decoding JSON from {expected_fields_json_path}: {e}") + raise json.JSONDecodeError( + f"Error decoding JSON from {expected_fields_json_path}: {e}", + e.doc, + e.pos, + ) + + # Convert JSON to the required dictionary format + self.expected_fields: dict[str, list[dict[str, str]]] = { + key.lower(): [{"name": field["name"], "type": field["type"]} for field in fields] + 
for key, fields in EXPECTED_FIELDS_JSON.items() + } + + # Initialize basin geometries storage + self.basin_geometries: Optional[list] = None + + def _validate_vector_layer( + self, + vector_layer: VectorLayer, + layer_type: str, + specific_validations: Optional[list] = None, + ) -> bool: + """ + Validate a vector layer's fields and geometries. Parameters ---------- - reach : VectorLayer - The vector layer which the reaches are in. + vector_layer : VectorLayer + The vector layer to validate. + layer_type : str + The type of layer (e.g., 'reaches', 'basins'). + specific_validations : Optional[List], optional + Additional specific validation functions to apply, by default None. Returns ------- - list - A list of the reach objects. + bool + True if all validations pass, False otherwise. + """ + shapefile_name = layer_type.capitalize() + logging.info(f"Starting validation for {shapefile_name} layer.") + + # Retrieve expected fields for the layer type + expected_fields = self.expected_fields.get(layer_type.lower()) + logging.info(f"Expected fields for {shapefile_name}: {expected_fields}") + + if expected_fields is None: + logging.error(f"No expected fields defined for layer type '{layer_type}'.") + return False + + # If expected_fields is empty, skip field validation + if expected_fields: + logging.info(f"Performing field validation for {shapefile_name}.") + # Validate shapefile fields + fields_valid = validate_shapefile_fields( + vector_layer=vector_layer, + shapefile_name=shapefile_name, + expected_fields=expected_fields, + ) + if not fields_valid: + logging.error(f"Field validation failed for {shapefile_name}.") + return False + else: + logging.info(f"Field validation passed for {shapefile_name}.") + else: + logging.info(f"No expected fields specified for {shapefile_name}, skipping field validation.") + + # Validate shapefile geometries + logging.info(f"Performing geometry validation for {shapefile_name}.") + geometries_valid = 
validate_shapefile_geometries(vector_layer, layer_type) + if not geometries_valid: + logging.error(f"Geometry validation failed for {shapefile_name}.") + return False + else: + logging.info(f"Geometry validation passed for {shapefile_name}.") + + # Perform any specific validations if provided + if specific_validations: + logging.info(f"Performing specific validations for {shapefile_name}.") + for validation_func in specific_validations: + logging.info(f"Running validation function '{validation_func.__name__}' for {shapefile_name}.") + valid = validation_func(vector_layer, shapefile_name) + if not valid: + logging.error(f"Specific validation '{validation_func.__name__}' failed for {shapefile_name}.") + return False + else: + logging.info(f"Specific validation '{validation_func.__name__}' passed for {shapefile_name}.") + else: + logging.info(f"No specific validations provided for {shapefile_name}.") + + logging.info(f"Validation passed for {shapefile_name} layer.") + return True + + def reach(self, reach_layer: VectorLayer) -> list[Reach]: + """ + Build the reach objects. """ + logging.info("Starting to build Reach objects.") + + # Validate the reach vector layer + if not self._validate_vector_layer(reach_layer, "reaches"): + logging.error("Reach vector layer validation failed.") + raise ValueError("Reach vector layer validation failed.") + else: + logging.info("Reach vector layer validation passed.") reaches = [] - for i in range(len(reach)): - s = reach.geometry(i) - r = reach.record(i) - reaches.append(Reach(r["id"], s, ReachType(r["t"]), r["s"])) - return reaches + num_features = len(reach_layer) + logging.info(f"Number of features in reach layer: {num_features}") - def basin(self, centroid: VectorLayer, basin: VectorLayer) -> list: - """Build the basin objects. 
+ for i in range(num_features): + # logging.info(f"Processing reach feature index {i}") + try: + # Get standardized geometry and attributes + geometry_coords = reach_layer.get_geometry(i) # List of (x, y) tuples + attributes = reach_layer.get_attributes(i) + # logging.info(f"Geometry coordinates for feature {i}: {geometry_coords}") + # logging.info(f"Attributes for feature {i}: {attributes}") - Parameters - ---------- - centroid : VectorLayer - The vector layer which the centroids are in. - basin : VectorLayer - The vector layer which the basins are in. + # Create OGR LineString geometry using the geometry function + ogr_geom = create_line_string(geometry_coords) + # logging.info(f"OGR geometry created for feature {i}") - Returns - ------- - list - A list of the basin objects. + if is_geometry_empty(ogr_geom): + logging.warning(f"Empty geometry at Reaches index {i}. Skipping.") + continue + + # Extract required attributes + reach_id = attributes["id"] + reach_type_value = attributes["t"] + reach_s = attributes["s"] + # logging.info( + # f"Extracted attributes for feature {i} - ID: {reach_id}, Type: {reach_type_value}, Slope: {reach_s}" + # ) + + # Convert the reach type to the ReachType enum + reach_type = ReachType(reach_type_value) + # logging.info(f"ReachType enum for feature {i}: {reach_type}") + + # Convert list of tuples to list of Node instances + node_vector = [Node(x=x, y=y) for x, y in geometry_coords] + # logging.info(f"Node vector for feature {i}: {node_vector}") + + # Create the Reach object + reach = Reach( + name=reach_id, + vector=node_vector, # Pass List[Node] instead of List[Tuple[float, float]] + reachType=reach_type, + slope=reach_s, + ) + reaches.append(reach) + # logging.info(f"Successfully created Reach object for feature {i} with ID '{reach_id}'") + + # except KeyError as e: + # logging.error( + # f"Missing expected field {e} in Reaches record {i}. 
Available attributes: {attributes.keys()}" + # ) + # raise + except ValueError as e: + logging.error(f"Value error processing Reaches record {i}: {e}") + raise + except Exception as e: + logging.error(f"Unexpected error processing Reaches record {i}: {e}") + raise + + logging.info(f"Successfully built {len(reaches)} Reach objects.") + return reaches + + def basin(self, centroid_layer: VectorLayer, basin_layer: VectorLayer) -> list[Basin]: """ + Build the basin objects using GDAL/OGR for geometry operations. + """ + # Validate the basin vector layer + if not self._validate_vector_layer(basin_layer, "basins"): + raise ValueError("Basin vector layer validation failed.") + + # Validate the centroid vector layer + if not self._validate_vector_layer(centroid_layer, "centroids"): + raise ValueError("Centroid vector layer validation failed.") + basins = [] - # Precompute Shapely polygons for all basins - basin_geometries = [basin.shapely_geometry(j) for j in range(len(basin))] + basin_geometries = [] + + # Precompute OGR geometries for all basins + for j in range(len(basin_layer)): + basin_attributes = basin_layer.get_attributes(j) + geometry_coords = basin_layer.get_geometry(j) # List of (x, y) tuples + + # Create OGR Polygon geometry using the geometry function + polygon = create_polygon(geometry_coords) + + if is_geometry_empty(polygon): + logging.warning(f"Empty geometry at Basins index {j}. Skipping.") + continue + + # Store the OGR geometry and attributes + basin_geometries.append((polygon, basin_attributes)) + + for i in range(len(centroid_layer)): + centroid_attributes = centroid_layer.get_attributes(i) + centroid_coords = centroid_layer.get_geometry(i) # List of (x, y) tuples + + if not centroid_coords: + logging.warning(f"Empty geometry at Centroids index {i}. 
Skipping.") + continue + + # Create OGR Point geometry using the geometry function + x, y = centroid_coords[0] + point = create_point(x, y) - for i in range(len(centroid)): - centroid_geom = centroid.shapely_geometry(i) - centroid_point = centroid_geom.centroid # Shapely Point object matching_basins = [] # Find all basins that contain the centroid point - for j, basin_geom in enumerate(basin_geometries): - if basin_geom.contains(centroid_point): - matching_basins.append(j) + for j, (basin_geom, basin_attributes) in enumerate(basin_geometries): + if contains(basin_geom, point): + matching_basins.append((j, basin_geom, basin_attributes)) if not matching_basins: - logging.warning( - f"Centroid ID {centroid.record(i)['id']} at ({centroid_point.x}, {centroid_point.y}) " - f"is not contained within any basin polygon." - ) + centroid_id = centroid_attributes.get("id", f"Index {i}") + logging.warning(f"Centroid ID {centroid_id} at ({x}, {y}) is not contained within any basin polygon.") continue # Skip this centroid or handle as needed if len(matching_basins) > 1: + centroid_id = centroid_attributes.get("id", f"Index {i}") logging.error( - f"Centroid ID {centroid.record(i)['id']} at ({centroid_point.x}, {centroid_point.y}) " - f"is contained within multiple basins: {matching_basins}. " + f"Centroid ID {centroid_id} at ({x}, {y}) " + f"is contained within multiple basins: {[idx for idx, _, _ in matching_basins]}. " f"Associating with the first matching basin." 
) # Associate with the first matching basin - associated_basin_idx = matching_basins[0] - associated_basin_geom = basin_geometries[associated_basin_idx] - # Area in the units of the shapefile's projection - a = associated_basin_geom.area - r = centroid.record(i) - fi = r["fi"] - p = centroid_geom.centroid.coords[0] - basins.append(Basin(r["id"], p[0], p[1], (a / 1e6), fi)) + associated_basin_idx, associated_basin_geom, associated_basin_attributes = matching_basins[0] - return basins + # Calculate area using the geometry function + area = calculate_area(associated_basin_geom) - def confluence(self, confluence: VectorLayer) -> list: - """Build the confluence objects + # Convert area to square kilometers if necessary (depends on CRS units) + # For example, if units are in meters: + area_km2 = area / 1e6 # Convert from square meters to square kilometers - Parameters - ---------- - confluence : VectorLayer - The vector layer the confluences are on. + try: + basin_id = centroid_attributes["id"] + fi = centroid_attributes["fi"] + basins.append(Basin(basin_id, x, y, area_km2, fi)) + except KeyError as e: + logging.error(f"Missing expected field {e} in Centroids record {i}.") + raise - Returns - ------- - list - A list of confluence objects. + logging.info(f"Successfully built {len(basins)} Basin objects.") + return basins + + def confluence(self, confluence_layer: VectorLayer) -> list[Confluence]: + """ + Build the confluence objects. 
""" + # Validate the confluence vector layer with specific 'out' field validation + if not self._validate_vector_layer( + confluence_layer, + "confluences", + specific_validations=[validate_confluences_out_field], + ): + raise ValueError("Confluence vector layer validation failed.") + confluences = [] - for i in range(len(confluence)): - s = confluence.geometry(i) - p = s[0] - r = confluence.record(i) - confluences.append(Confluence(r["id"], p[0], p[1], bool(r["out"]))) + for i in range(len(confluence_layer)): + attributes = confluence_layer.get_attributes(i) + geometry_coords = confluence_layer.get_geometry(i) # List of (x, y) tuples + + if not geometry_coords: + logging.warning(f"Empty geometry at Confluences index {i}. Skipping.") + continue + + # Assuming confluence is a point geometry with one coordinate + x, y = geometry_coords[0] + + try: + confluence_id = attributes["id"] + out_field = attributes["out"] + confluences.append(Confluence(confluence_id, x, y, bool(out_field))) + except KeyError as e: + logging.error(f"Missing expected field {e} in Confluences record {i}.") + raise + + logging.info(f"Successfully built {len(confluences)} Confluence objects.") return confluences diff --git a/src/pyromb/core/gis/vector_layer.py b/src/pyromb/core/gis/vector_layer.py index 89e3952..0cb21e5 100644 --- a/src/pyromb/core/gis/vector_layer.py +++ b/src/pyromb/core/gis/vector_layer.py @@ -1,18 +1,21 @@ - +# src\pyromb\core\gis\vector_layer.py import abc +from typing import Any +from osgeo import ogr + class VectorLayer(abc.ABC): """ - Interface for reading shapefiles. - - Used by the Builder to access the geometry and attributes of the - shapefile to build the catchment objects. Given the various ways a shapefile can - be read, the VectorLayer Class wrappes the functionality of reading the shapefile - by the chosen library in a consistent interface to be used by the builder. + Interface for reading shapefiles. 
+ + Used by the Builder to access the geometry and attributes of the + shapefile to build the catchment objects. Given the various ways a shapefile can + be read, the VectorLayer Class wrappes the functionality of reading the shapefile + by the chosen library in a consistent interface to be used by the builder. """ - + @abc.abstractmethod - def geometry(self, i: int) -> list: + def geometry(self, i: int) -> list[tuple[float, float]]: """ Method to access the geometry of the ith vector in the shapefile. @@ -31,9 +34,9 @@ def geometry(self, i: int) -> list: pass @abc.abstractmethod - def record(self, i: int) -> dict: + def record(self, i: int) -> dict[str, Any]: """ - Method to access the attributes of the ith vector in the shapefile. + Method to access the attributes of the ith vector in the shapefile. Return the set of attributes as a dictionary. @@ -45,7 +48,7 @@ def record(self, i: int) -> dict: Returns ------- dict - key:value pair of the attributes. + key:value pair of the attributes. """ pass @@ -56,6 +59,44 @@ def __len__(self) -> int: Returns ------- int - Vectors in the shapefile. + Vectors in the shapefile. + """ + pass + + def get_geometry(self, i: int) -> list[tuple[float, float]]: + """ + Returns the geometry of the ith vector as a list of (x, y) tuples. + Default implementation assumes geometry(i) returns this format. + Subclasses can override this method if necessary. + """ + return self.geometry(i) + + def get_attributes(self, i: int) -> dict[str, Any]: """ - pass \ No newline at end of file + Returns the attributes of the ith vector as a dictionary. + Default implementation assumes record(i) returns a dict-like object. + Subclasses can override this method for custom behavior. + """ + record = self.record(i) + if isinstance(record, dict): + return record + else: + raise TypeError(f"Not expected type: {record}") + + def get_fields(self) -> list[tuple[str, int]]: + """ + Return field names and types for the vector layer. 
+ + This method assumes the layer provides access to its field names and types. + Subclasses should implement this if it is required by the validation logic. + + Returns + ------- + List of tuples containing field names and types. + """ + raise NotImplementedError("Subclasses must implement get_fields() if needed.") + + # @abc.abstractmethod + # def get_ogr_geometry(self, i: int) -> ogr.Geometry: + # """Return the OGR geometry object for the ith feature.""" + # pass diff --git a/src/pyromb/core/traveller.py b/src/pyromb/core/traveller.py index bd4fc0d..55567b7 100644 --- a/src/pyromb/core/traveller.py +++ b/src/pyromb/core/traveller.py @@ -4,16 +4,19 @@ from .catchment import Catchment from ..model.model import Model +import logging + + class Traveller: """The Traveller walks through the catchment, proceeding from the very most upstream - basin to the outfall location. + basin to the outfall location. The walk is performed in a breadth first manner processing all the upstream catchments first, then walking down till it finds a confluence and jumps back up to the most - upstream sub-basin. So that RORB can be built correctly, the traveller has the option to - pause on the confluence before proceeding to the next upstream reach. This allows for a + upstream sub-basin. So that RORB can be built correctly, the traveller has the option to + pause on the confluence before proceeding to the next upstream reach. This allows for a save step to be performed in the RORB model. WBNM does not require such a step. - + Parameters ---------- catchment : Catchment @@ -34,38 +37,40 @@ def position(self) -> int: Returns ------- int - The current position of the traveller. + The current position of the traveller. """ return self._pos - + def getStart(self) -> int: - """Gets the position of the outlet node of the basin. - - That is the most downstream node. Assumes that there is only one + """Gets the position of the outlet node of the basin. + + That is the most downstream node. 
Assumes that there is only one outlet in a basin. i.e. only one node with no reaches downstream of it. Returns ------- int - The index of the outlet node. + The index of the outlet node. """ for i, val in enumerate(self._ds): if sum(val) == (-len(val)): return i - + else: + return -1 + def getReach(self, i: int) -> Reach: - """The downstream reach connected to ith node. + """The downstream reach connected to ith node. Parameters ---------- i : int - The index of the node we wish to get the downstream reach for. + The index of the node we wish to get the downstream reach for. Returns ------- Reach - The reach downstream of the ith node. + The reach downstream of the ith node. Raises ------ @@ -76,14 +81,14 @@ def getReach(self, i: int) -> Reach: if val != self._endSentinel: return self._catchment._edges[j] raise KeyError - + def getNode(self, i: int) -> Node: """The ith node. Parameters ---------- i : int - The index of the node to return. + The index of the node to return. Returns ------- @@ -93,7 +98,7 @@ def getNode(self, i: int) -> Node: return self._catchment._vertices[i] def top(self, i: int) -> int: - """ The node index of the most upstream catchment avaiable from node i. + """The node index of the most upstream catchment avaiable from node i. Does not update the position of the traveller, that is the traveller does not travel to this node. An avaiable catchment is one which has not been visited by _next(). @@ -101,12 +106,12 @@ def top(self, i: int) -> int: Parameters ---------- i : int - The position to query the most upstream catchment from. + The position to query the most upstream catchment from. Returns ------- int - The index of the node. + The index of the node. """ for val in self._us[i]: if val != -1: @@ -115,16 +120,16 @@ def top(self, i: int) -> int: else: continue return i - + def up(self, i: int) -> list: - """Returns the immediate upstream nodes from position i. - + """Returns the immediate upstream nodes from position i. 
+ A subarea can have multiple upstream nodes and so will return all of them. Parameters ---------- i : int - the position which the upstream nodes are to be queried. + the position which the upstream nodes are to be queried. Returns ------- @@ -136,13 +141,13 @@ def up(self, i: int) -> list: def down(self, i: int) -> int: """The index of the immediate downstream node along the reach. - Does not update the position of the traveller. If there is no downstream + Does not update the position of the traveller. If there is no downstream node then -1 is returned. Parameters ---------- i : int - The position at which the downstream nodes are to be queried from. + The position at which the downstream nodes are to be queried from. Returns ------- @@ -153,20 +158,23 @@ def down(self, i: int) -> int: if val != -1: return val return self._endSentinel - + def next(self) -> int: - """The next upstream node within the catchment available from the current position. - - If the current position is on a confluence next() will return that node before - going up the next reach. This is due to RORB needing to save the state at that - confluence before calculating the hydrographs of upstream sub basins. - + """The next upstream node within the catchment available from the current position. + + If the current position is on a confluence next() will return that node before + going up the next reach. This is due to RORB needing to save the state at that + confluence before calculating the hydrographs of upstream sub basins. + Returns ------- int - The index of the upstream node. + The index of the upstream node. """ + top = self.top(self._pos) + logging.debug(f"Traveller moving from position {self._pos} to {top}") + top = self.top(self._pos) if top == self._pos: self._colour[self._pos] = 1 @@ -175,16 +183,16 @@ def next(self) -> int: else: self._pos = top return self._pos - + def nextAbsolute(self) -> int: """The absolute upper most node availabe from the current position. 
nextAbsolute ignores the intermediate confluences and moves to the most - upstream point in the catchment that hasn't been visited. If it can reach a - higher node it will travel to this node and return it. next must be called + upstream point in the catchment that hasn't been visited. If it can reach a + higher node it will travel to this node and return it. next must be called prior to using nextAbsolute as it assumes the traveller started at a node - with no reaches above it, that is, at the top of the catchment. - + with no reaches above it, that is, at the top of the catchment. + Returns ------- int @@ -205,14 +213,14 @@ def nextAbsolute(self) -> int: return self._pos def getVector(self, model: Model) -> str: - """Produce the vector for the desired hydrology model. - - Supports either RORB or WBNM. + """Produce the vector for the desired hydrology model. + + Supports either RORB or WBNM. Parameters ---------- model : Model - The hydrology model to generate the control file for. + The hydrology model to generate the control file for. Returns ------- @@ -220,4 +228,4 @@ def getVector(self, model: Model) -> str: The control file string. """ - return model.getVector(self) \ No newline at end of file + return model.getVector(self) diff --git a/src/pyromb/math/geometry.py b/src/pyromb/math/geometry.py index cace2f8..124cdf0 100644 --- a/src/pyromb/math/geometry.py +++ b/src/pyromb/math/geometry.py @@ -1,82 +1,392 @@ +# src/pyromb/math/geometry.py import math -from ..core.geometry.point import Point +from typing import Union +from osgeo import ogr +from ..core.geometry.point import Point # Ensure correct import path +import logging -def length(vertices:list) -> float: - """Calculate the cartesian length of a vector of co-ordinates. 
+ +class GeometryError(Exception): + """Custom exception for geometry-related errors.""" + + pass + + +def wkbFlatten(geometry_type: int) -> int: + return geometry_type & (~ogr.wkb25DBit) + + +def calculate_length(geometry: ogr.Geometry) -> float: + """ + Calculate the length of an OGR geometry. Parameters ---------- - vertices : list - The list of co-ordinates to calculate the length. + geometry : ogr.Geometry + The geometry to calculate the length for. Returns ------- float - The vector length. + The length of the geometry. + + Raises + ------ + GeometryError + If the geometry is None or empty. """ + if geometry is None or geometry.IsEmpty(): + raise GeometryError("Cannot calculate length: Geometry is None or empty.") + return geometry.Length() - length = 0 - for i in range(len(vertices) - 1): - length += math.sqrt( \ - math.pow((vertices[i+1].coordinates()[0] - vertices[i].coordinates()[0]), 2) + \ - math.pow((vertices[i+1].coordinates()[1] - vertices[i].coordinates()[1]), 2) - ) - return length -# Shoelace algorithm -def polygon_area(vertices:list) -> float: - """Calculate the cartesian area of a polygon. +def calculate_area(geometry: ogr.Geometry) -> float: + """ + Calculate the area of an OGR Polygon or MultiPolygon geometry. Parameters ---------- - vertices : list - A list of points representing the polygon. + geometry : ogr.Geometry + The polygon geometry to calculate the area for. Returns ------- float - The polygon area. + The area of the polygon. + + Raises + ------ + GeometryError + If the geometry is None, empty, or not a Polygon/MultiPolygon. 
""" + if geometry is None or geometry.IsEmpty(): + raise GeometryError("Cannot calculate area: Geometry is None or empty.") + geom_type = wkbFlatten(geometry.GetGeometryType()) + if geom_type not in [ogr.wkbPolygon, ogr.wkbMultiPolygon]: + raise GeometryError("Cannot calculate area: Geometry is not a Polygon or MultiPolygon.") + return geometry.GetArea() - psum = 0 - nsum = 0 - for i in range(len(vertices)): - sindex = (i + 1) % len(vertices) - prod = vertices[i].coordinates()[0] * vertices[sindex].coordinates()[1] - psum += prod - for i in range(len(vertices)): - sindex = (i + 1) % len(vertices) - prod = vertices[sindex].coordinates()[0] * vertices[i].coordinates()[1] - nsum += prod +def calculate_centroid(geometry: ogr.Geometry) -> Point: + """ + Calculate the centroid of an OGR Polygon or MultiPolygon geometry. - return abs(1/2*(psum - nsum)) + Parameters + ---------- + geometry : ogr.Geometry + The polygon geometry to calculate the centroid for. -def polygon_centroid(vertices:list) -> Point: - """Calculate the centroid of a polygon. + Returns + ------- + Point + The centroid as a Point object. + + Raises + ------ + GeometryError + If the geometry is None, empty, or not a Polygon/MultiPolygon. + If the centroid calculation fails. 
+ """ + if geometry is None or geometry.IsEmpty(): + raise GeometryError("Cannot calculate centroid: Geometry is None or empty.") + geom_type = wkbFlatten(geometry.GetGeometryType()) + if geom_type not in [ogr.wkbPolygon, ogr.wkbMultiPolygon]: + raise GeometryError("Cannot calculate centroid: Geometry is not a Polygon or MultiPolygon.") + + centroid_geom = geometry.Centroid() + if centroid_geom is None or centroid_geom.IsEmpty(): + raise GeometryError("Centroid calculation failed: Resulting centroid geometry is None or empty.") + + return Point(x=centroid_geom.GetX(), y=centroid_geom.GetY()) + + +def calculate_distance(pt1: ogr.Geometry, pt2: ogr.Geometry) -> float: + """ + Calculate the Euclidean distance between two OGR point geometries. Parameters ---------- - vertices : list - A list of points representing the polygon. + pt1 : ogr.Geometry + The first point geometry. + pt2 : ogr.Geometry + The second point geometry. Returns ------- - Point - The centroid. + float + The Euclidean distance between pt1 and pt2. + + Raises + ------ + GeometryError + If either pt1 or pt2 is None, empty, or not a point geometry. + """ + if pt1 is None or pt1.IsEmpty(): + raise GeometryError("First point is None or empty.") + if pt2 is None or pt2.IsEmpty(): + raise GeometryError("Second point is None or empty.") + if wkbFlatten(pt1.GetGeometryType()) != ogr.wkbPoint: + raise GeometryError("First geometry is not a Point.") + if wkbFlatten(pt2.GetGeometryType()) != ogr.wkbPoint: + raise GeometryError("Second geometry is not a Point.") + + return math.sqrt((pt1.GetX() - pt2.GetX()) ** 2 + (pt1.GetY() - pt2.GetY()) ** 2) + + +def point_on_reference(pt: ogr.Geometry, ref_points: list[ogr.Geometry], tolerance: float = 1e-6) -> bool: + """ + Check if a point coincides with any reference point within a given tolerance. + + Parameters + ---------- + pt : ogr.Geometry + The point to check. + ref_points : List[ogr.Geometry] + The list of reference points. 
+ tolerance : float, optional + The tolerance distance, by default 1e-6. + + Returns + ------- + bool + True if the point coincides with any reference point within the tolerance, False otherwise. + + Raises + ------ + GeometryError + If pt is None, empty, or not a point geometry. + If any reference point in ref_points is None, empty, or not a point geometry. + """ + if pt is None or pt.IsEmpty(): + raise GeometryError("Input point is None or empty.") + if wkbFlatten(pt.GetGeometryType()) != ogr.wkbPoint: + raise GeometryError("Input geometry is not a Point.") + + for ref_pt in ref_points: + if ref_pt is None or ref_pt.IsEmpty(): + raise GeometryError("Reference point is None or empty.") + if wkbFlatten(ref_pt.GetGeometryType()) != ogr.wkbPoint: + raise GeometryError("Reference geometry is not a Point.") + distance = calculate_distance(pt, ref_pt) + if distance <= tolerance: + return True + return False + + +def create_line_string(coords: list[tuple[float, float]]) -> ogr.Geometry: + """ + Create an OGR LineString geometry from a list of (x, y) tuples. + + Parameters + ---------- + coords : List[Tuple[float, float]] + The coordinates of the LineString vertices. + + Returns + ------- + ogr.Geometry + The created LineString geometry. + + Raises + ------ + GeometryError + If coords is empty or contains invalid coordinate tuples. + """ + if not coords: + raise GeometryError("Coordinate list is empty.") + + line = ogr.Geometry(ogr.wkbLineString) + for idx, (x, y) in enumerate(coords): + if not isinstance(x, (int, float)) or not isinstance(y, (int, float)): + raise GeometryError(f"Invalid coordinates at index {idx}: ({x}, {y})") + line.AddPoint(x, y) + return line + + +def create_polygon(coords: list[tuple[float, float]]) -> ogr.Geometry: + """ + Create an OGR Polygon geometry from a list of (x, y) tuples. + + Parameters + ---------- + coords : List[Tuple[float, float]] + The coordinates of the Polygon vertices. 
The first and last points must be the same to close the ring. + + Returns + ------- + ogr.Geometry + The created Polygon geometry. + + Raises + ------ + GeometryError + If coords are insufficient to form a polygon or if the ring is not closed. + """ + if not coords: + raise GeometryError("Coordinate list is empty.") + if len(coords) < 4: + raise GeometryError("At least four coordinates are required to form a Polygon (including closure).") + if coords[0] != coords[-1]: + raise GeometryError("Polygon ring is not closed. First and last coordinates must be the same.") + + ring = ogr.Geometry(ogr.wkbLinearRing) + for idx, (x, y) in enumerate(coords): + if not isinstance(x, (int, float)) or not isinstance(y, (int, float)): + raise GeometryError(f"Invalid coordinates at index {idx}: ({x}, {y})") + ring.AddPoint(x, y) + + polygon = ogr.Geometry(ogr.wkbPolygon) + polygon.AddGeometry(ring) + return polygon + + +def create_point(x: float, y: float) -> ogr.Geometry: + """ + Create an OGR Point geometry from x and y coordinates. + + Parameters + ---------- + x : float + The X coordinate. + y : float + The Y coordinate. + + Returns + ------- + ogr.Geometry + The created Point geometry. + + Raises + ------ + GeometryError + If x or y is not a valid number. + """ + if not isinstance(x, (int, float)) or not isinstance(y, (int, float)): + raise GeometryError(f"Invalid coordinates: ({x}, {y})") + + point = ogr.Geometry(ogr.wkbPoint) + point.AddPoint(x, y) + return point + + +def is_geometry_empty(geom: ogr.Geometry) -> bool: + """ + Check if an OGR geometry is empty. + + Parameters + ---------- + geom : ogr.Geometry + The geometry to check. + + Returns + ------- + bool + True if the geometry is empty, False otherwise. + + Raises + ------ + GeometryError + If geom is None. + """ + if geom is None: + raise GeometryError("Geometry is None.") + return geom.IsEmpty() + + +def contains(polygon: ogr.Geometry, point: ogr.Geometry) -> bool: + """ + Check if a polygon contains a point. 
+ + Parameters + ---------- + polygon : ogr.Geometry + The Polygon or MultiPolygon geometry. + point : ogr.Geometry + The Point geometry. + + Returns + ------- + bool + True if the polygon contains the point, False otherwise. + + Raises + ------ + GeometryError + If either polygon or point is None, empty, or of incorrect geometry types. """ - - sumx = 0 - sumy = 0 - suma = 0 + if polygon is None or polygon.IsEmpty(): + raise GeometryError("Polygon geometry is None or empty.") + if point is None or point.IsEmpty(): + raise GeometryError("Point geometry is None or empty.") + + # The wkbFlatten function is not directly available in the osgeo.ogr module in some versions of GDAL/OGR's Python + # bindings. This function is used in the C++ API but may not be exposed in the Python API. + polygon_geom_type = wkbFlatten(polygon.GetGeometryType()) + point_geom_type = wkbFlatten(point.GetGeometryType()) + + if polygon_geom_type not in [ogr.wkbPolygon, ogr.wkbMultiPolygon]: + raise GeometryError( + f"First geometry is not a Polygon or MultiPolygon: {ogr.GeometryTypeToName(polygon_geom_type)}" + ) + if point_geom_type != ogr.wkbPoint: + raise GeometryError(f"Second geometry is not a Point: {ogr.GeometryTypeToName(point_geom_type)}") + + # Handle MultiPolygon with multiple geometries + if polygon_geom_type == ogr.wkbMultiPolygon: + num_geoms = polygon.GetGeometryCount() + if num_geoms > 1: + logging.warning(f"Polygon is a MultiPolygon with {num_geoms} parts.") + elif num_geoms == 1: + logging.warning("Polygon is a MultiPolygon with a single part. Proceeding.") + polygon = polygon.GetGeometryRef(0) # Use the first (and only) polygon + + return polygon.Contains(point) + + +# Define a type alias for clarity +Coordinate = Union[tuple[float, float], list[float]] + + +def length(vertices: list[Coordinate]) -> float: + """ + Calculate the Cartesian length of a vector defined by a list of coordinates. 
+ + Parameters + ---------- + vertices : List[Coordinate] + The list of coordinates to calculate the length. Each coordinate should be a tuple or list + containing at least two numerical values representing (x, y). + + Returns + ------- + float + The total Cartesian length of the vector. + + Raises + ------ + ValueError + If a vertex does not contain at least two numerical values. + TypeError + If the vertices are not provided as a list of tuples or lists. + """ + if not isinstance(vertices, list): + raise TypeError("vertices must be a list of coordinate tuples or lists.") + + total_length = 0.0 + for i in range(len(vertices) - 1): - p = [(vertices[i].coordinates()[0], vertices[i].coordinates()[1]), (vertices[i+1].coordinates()[0], vertices[i+1].coordinates()[1])] - sumx += (p[0][0] + p[1][0]) * (p[0][0] * p[1][1] - p[1][0] * p[0][1]) - sumy += (p[0][1] + p[1][1]) * (p[0][0] * p[1][1] - p[1][0] * p[0][1]) - suma += p[0][0] * p[1][1] - p[1][0] * p[0][1] - - A = 0.5 * suma - Cx = (1 / (6 * A)) * sumx - Cy = (1 / (6 * A)) * sumy - - return Point(Cx, Cy) \ No newline at end of file + current_vertex = vertices[i] + next_vertex = vertices[i + 1] + + # Ensure each vertex has at least two elements + if not (isinstance(current_vertex, (tuple, list)) and isinstance(next_vertex, (tuple, list))): + raise TypeError("Each vertex must be a tuple or list containing numerical coordinates.") + if len(current_vertex) < 2 or len(next_vertex) < 2: + raise ValueError("Each vertex must contain at least two numerical values for (x, y).") + + dx: float = next_vertex[0] - current_vertex[0] + dy: float = next_vertex[1] - current_vertex[1] + segment_length: float = math.hypot(dx, dy) + total_length += segment_length + + return total_length diff --git a/src/pyromb/model/rorb.py b/src/pyromb/model/rorb.py index aec6dec..f895e77 100644 --- a/src/pyromb/model/rorb.py +++ b/src/pyromb/model/rorb.py @@ -1,3 +1,4 @@ +# src\pyromb\model\rorb.py from .model import Model from ..core.traveller import 
Traveller from ..core.attributes.basin import Basin @@ -7,7 +8,10 @@ import json import os -class VectorBlock(): +import logging + + +class VectorBlock: """ Builds the vector block for the RORB control file. """ @@ -18,45 +22,47 @@ def __init__(self) -> None: self._stateVector = [] self._controlVector = [] - resources_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources') - with open(os.path.join(resources_dir, 'formatting.json'), 'r') as f: + resources_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources") + with open(os.path.join(resources_dir, "formatting.json"), "r") as f: self._formattingOptions = json.load(f) - + def step(self, traveller: Traveller) -> None: - """ + """ Calculate action to take at the current step and store it in the VectorBlock's state. - - Step is to be used at every update of the Traveller. The RORB control vector + + Step is to be used at every update of the Traveller. The RORB control vector is then built from the VectorBlock's state after the catchment has been traversed. Parameters ---------- traveller : Traveller - The traveller for the catchment being built. + The traveller for the catchment being built. """ self._state(traveller) self._control(self._stateVector[-1], traveller) - + def build(self, traveller: Traveller) -> str: - """ + """ Builds the vector block string. - + Parameters ---------- traveller: Traveller The traveller that traversed the catchment. - + Returns ------- str - The vector block string to be used in the .catg file + The vector block string to be used in the .catg file """ - vectorStr = "0\n" # Start with code 0, reach types are specified in the control block. + vectorStr = "0\n" # Start with code 0, reach types are specified in the control block. 
for s in self._controlVector: vectorStr += f"{s}\n" - vectorStr += f"{self._subAreaStr(self._stateVector, traveller)}\n{self._fracImpStr(self._stateVector, traveller)}\n" + vectorStr += ( + f"{self._subAreaStr(self._stateVector, traveller)}\n{self._fracImpStr(self._stateVector, traveller)}\n" + ) return vectorStr def _state(self, traveller: Traveller) -> None: @@ -78,32 +84,82 @@ def _state(self, traveller: Traveller) -> None: The traveller traversing this catchment. """ + # Retrieve current position and upstream position i = traveller._pos up = traveller.top(i) - - if i == traveller._endSentinel: - ret = (0, i) - elif (self._runningHydro == False) and (isinstance(traveller._catchment._vertices[i], Basin)): - self._runningHydro = True - traveller.next() - ret = (1, i) - elif (self._storedHydro) and (self._storedHydro[-1] == i) and (self._runningHydro): - self._storedHydro.pop() - ret = (4, i) - elif (self._runningHydro) and (isinstance(traveller._catchment._vertices[i], Basin)) and (up == i): - traveller.next() - ret = (2, i) - elif (self._runningHydro) and (up != i): - self._storedHydro.append(i) - self._runningHydro = False - traveller.next() - ret = (3, i) - elif (self._runningHydro) and (isinstance(traveller._catchment._vertices[i], Confluence)) and (up == i): - traveller.next() - ret = (5, i) - - self._stateVector.append(ret) - + + # Log the current state + logging.debug("--- _state Method Invocation ---") + logging.debug(f"Current Position (_pos): {i}") + logging.debug(f"Top Position (traveller.top({i})): {up}") + logging.debug(f"Running Hydrograph (_runningHydro): {self._runningHydro}") + logging.debug(f"Stored Hydrograph Stack (_storedHydro): {self._storedHydro}") + + # Identify the node type + current_node = traveller._catchment._vertices[i] + node_type = type(current_node).__name__ + logging.debug(f"Current Node Type: {node_type}") + + try: + if i == traveller._endSentinel: + ret = (0, i) + logging.debug(f"Reached end sentinel. 
Setting ret to {ret}") + + elif (not self._runningHydro) and isinstance(current_node, Basin): + self._runningHydro = True + logging.debug(f"Node {i} is a Basin and no running hydrograph. Setting _runningHydro to True.") + traveller.next() + ret = (1, i) + logging.debug(f"Moved to next node. Setting ret to {ret}") + + elif (self._storedHydro) and (self._storedHydro[-1] == i) and self._runningHydro: + popped = self._storedHydro.pop() + logging.debug(f"Node {i} has a stored hydrograph. Popped {popped} from _storedHydro.") + ret = (4, i) + logging.debug(f"Setting ret to {ret}") + + elif (self._runningHydro) and isinstance(current_node, Basin) and (up == i): + logging.debug( + f"Node {i} is a Basin with running hydrograph and no upstream reaches. Moving to next node." + ) + traveller.next() + ret = (2, i) + logging.debug(f"Moved to next node. Setting ret to {ret}") + + elif (self._runningHydro) and (up != i): + self._storedHydro.append(i) + self._runningHydro = False + logging.debug( + f"Node {i} has upstream reaches. Appended to _storedHydro and set _runningHydro to False." + ) + traveller.next() + ret = (3, i) + logging.debug(f"Moved to next node. Setting ret to {ret}") + + elif (self._runningHydro) and isinstance(current_node, Confluence) and (up == i): + logging.debug( + f"Node {i} is a Confluence with running hydrograph and no upstream reaches. Moving to next node." + ) + traveller.next() + ret = (5, i) + logging.debug(f"Moved to next node. 
Setting ret to {ret}") + + else: + logging.error( + f"Unhandled state in _state method: " + f"Position={i}, Top={up}, _runningHydro={self._runningHydro}, " + f"_storedHydro={self._storedHydro}, Node Type={node_type}" + ) + raise ValueError("Incorrect value passed - not sure of the cause") + + # Append the result to the state vector + self._stateVector.append(ret) + logging.debug(f"Appended ret {ret} to _stateVector.") + + except Exception as e: + logging.exception(f"Exception occurred in _state method: {e}") + raise + def _control(self, code: tuple, traveller: Traveller) -> None: """ Format a control vector string according to the RORB manual Table 5-1 p.52 (version 6). @@ -120,25 +176,28 @@ def _control(self, code: tuple, traveller: Traveller) -> None: The traveller traversing this catchment. """ + # initialise with + ret = "ERROR if you see this" + if code[0] in (1, 2, 5): try: r = traveller.getReach(code[1]) - if (r.type == ReachType.NATURAL) or (r.type == ReachType.DROWNED): - ret = f"{code[0]},{r.type.value},{r.length() / 1000:.3f},-99" + if (r.reachType == ReachType.NATURAL) or (r.reachType == ReachType.DROWNED): + ret = f"{code[0]},{r.reachType.value},{r.length / 1000:.3f},-99" else: - ret = f"{code[0]},{r.type.value},{r.length() / 1000:.3f},{r.getSlope()},-99" + ret = f"{code[0]},{r.reachType.value},{r.length / 1000:.3f},{r.slope},-99" except: ret = f"{7}\n\n{0}" - + if (code[0] == 3) or (code[0] == 4): ret = f"{code[0]}" - - if (code[0] == 0): + + if code[0] == 0: ret = f"{7}\n\n'{0}" - + self._controlVector.append(ret) - - def _subAreaStr(self, code: tuple, traveller: Traveller) -> str: + + def _subAreaStr(self, code: list, traveller: Traveller) -> str: """ Format the subarea string according to the RORB manual. 
@@ -162,17 +221,16 @@ def _subAreaStr(self, code: tuple, traveller: Traveller) -> str: areaStr = "" for c in code: if (c[0] == 1) or (c[0] == 2): - areaStr += f"{traveller._catchment._vertices[c[1]].area:{self._formattingOptions['area_table']['percision']}}," - areaStr += '-99' + areaStr += ( + f"{traveller._catchment._vertices[c[1]].area:{self._formattingOptions['area_table']['percision']}}," + ) + areaStr += "-99" - values = areaStr.split(',') - formatted_values = ( - f"{resources.rorb.AREA_TABLE_HEADER}" - f"{self._makeTable(values, 'area_table')}" - ) + values = areaStr.split(",") + formatted_values = f"{resources.rorb.AREA_TABLE_HEADER}" f"{self._makeTable(values, 'area_table')}" return formatted_values - + def _fracImpStr(self, code: list, traveller: Traveller) -> str: """ Format the fraction impervious string according to the RORB manual. @@ -198,16 +256,13 @@ def _fracImpStr(self, code: list, traveller: Traveller) -> str: for c in code: if (c[0] == 1) or (c[0] == 2): fStr += f"{traveller._catchment._vertices[c[1]].fi:{self._formattingOptions['fi_table']['percision']}}," - fStr += ' -99' + fStr += " -99" - values = fStr.split(',') - formatted_values = ( - f"{values[0]} ,\n" - f"{self._makeTable(values[1:], 'fi_table')}" - ) + values = fStr.split(",") + formatted_values = f"{values[0]} ,\n" f"{self._makeTable(values[1:], 'fi_table')}" return formatted_values - + def _makeTable(self, value: list, table: str) -> str: """ Format a table string according to the RORB manual. 
@@ -233,17 +288,18 @@ def _makeTable(self, value: list, table: str) -> str: formatted_values += f"{val:{self._formattingOptions[table]['column_width']}}," formatted_values += f"\n{value[-1]}" - return formatted_values - + return formatted_values + @property def state(self): - return (self._stateVector) - + return self._stateVector + @state.setter def state(self): - raise AttributeError('State vector is generated not set') + raise AttributeError("State vector is generated not set") + -class GraphicsBlock(): +class GraphicsBlock: """ Builds the graphics block for the RORB control file. """ @@ -255,8 +311,8 @@ def __init__(self) -> None: self._nodeID = self._idGenerator() self._reachID = self._idGenerator() - resources_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources') - with open(os.path.join(resources_dir, 'formatting.json'), 'r') as f: + resources_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources") + with open(os.path.join(resources_dir, "formatting.json"), "r") as f: self._formattingOptions = json.load(f) def step(self, code: tuple, traveller: Traveller) -> None: @@ -277,10 +333,10 @@ def step(self, code: tuple, traveller: Traveller) -> None: self._nodeDisplay(code, traveller) self._reachDisplay(code, traveller) - + def build(self) -> str: """Build the graphical block string for the .catg file. - + Returns ------- str @@ -296,7 +352,7 @@ def build(self) -> str: f"{self._generateNodeString()}" f"{resources.rorb.LEADING_TOKEN}\n" f"{self._generateReachString()}" - f"{resources.rorb.GRAPHICAL_TAIL}" + f"{resources.rorb.GRAPHICAL_TAIL}" ) return graphicalStr @@ -304,7 +360,7 @@ def build(self) -> str: def _replaceIDTags(self, vector: list) -> None: """ Replace the ID tags in the vector with the ID generated by the ID generator. - + Parameters ---------- vector : list @@ -328,19 +384,19 @@ def _normalizeCoordinates(self, scale: float = 90.0, shift: float = 2.5) -> None The shift factor to apply to the coordinates. 
""" - xs = [row['x'] for row in self._nodeVector] - ys = [row['y'] for row in self._nodeVector] + xs = [row["x"] for row in self._nodeVector] + ys = [row["y"] for row in self._nodeVector] scale_x = max(xs) - min(xs) scale_y = max(ys) - min(ys) for i, row in enumerate(self._nodeVector): - self._nodeVector[i]['x'] = (row['x'] - min(xs)) / scale_x * scale + shift - self._nodeVector[i]['y'] = (row['y'] - min(ys)) / scale_y * scale + shift + self._nodeVector[i]["x"] = (row["x"] - min(xs)) / scale_x * scale + shift + self._nodeVector[i]["y"] = (row["y"] - min(ys)) / scale_y * scale + shift for i, row in enumerate(self._reachVector): - self._reachVector[i]['x'] = (row['x'] - min(xs)) / scale_x * scale + shift - self._reachVector[i]['y'] = (row['y'] - min(ys)) / scale_y * scale + shift - + self._reachVector[i]["x"] = (row["x"] - min(xs)) / scale_x * scale + shift + self._reachVector[i]["y"] = (row["y"] - min(ys)) / scale_y * scale + shift + def _generateNodeString(self) -> str: """ Generates the display information string for the nodes. @@ -351,15 +407,15 @@ def _generateNodeString(self) -> str: nodeStr = resources.rorb.NODE_HEADER nodeStr += f"{resources.rorb.LEADING_TOKEN}{len(self._nodeVector):>7}\n" - + for row in self._nodeVector: nodeStr += resources.rorb.LEADING_TOKEN for item in row: - nodeStr += f"{row[item]:{self._formattingOptions['node'][item]}}" + nodeStr += f"{row[item]:{self._formattingOptions['node'][item]}}" nodeStr += f"\n{resources.rorb.LEADING_TOKEN}\n" - + return nodeStr - + def _generateReachString(self) -> str: """ Generates the display information string for the reaches. 
@@ -374,8 +430,8 @@ def _generateReachString(self) -> str: for row in self._reachVector: reachStr += resources.rorb.LEADING_TOKEN for item in row: - if (item == 'x') or (item == 'y'): - reachStr += f"\n{resources.rorb.LEADING_TOKEN}" + if (item == "x") or (item == "y"): + reachStr += f"\n{resources.rorb.LEADING_TOKEN}" reachStr += f"{row[item]:{self._formattingOptions['reach'][item]}}" reachStr += "\n" @@ -407,25 +463,25 @@ def _nodeDisplay(self, code: tuple, traveller: Traveller) -> None: ds_node = traveller.getNode(traveller.down(pos)) ds_name = f"<{ds_node.name}>" - + # Order according to the column order in the control vector. data = { - 'id': f"<{node.name}>", - 'x': x, - 'y': y, - 'icon': 1, - 'basin': int(isinstance(node, Basin)), - 'end': int(node.isOut) if isinstance(node, Confluence) else 0, - 'ds': ds_name, - 'name': f" {node.name}", - 'area': node.area if isinstance(node, Basin) else 0, - 'fi': node.fi if isinstance(node, Basin) else 0, - 'print': prnt, - 'excess': 0, - 'comment': 0 + "id": f"<{node.name}>", + "x": x, + "y": y, + "icon": 1, + "basin": int(isinstance(node, Basin)), + "end": int(node.isOut) if isinstance(node, Confluence) else 0, + "ds": ds_name, + "name": f" {node.name}", + "area": node.area if isinstance(node, Basin) else 0, + "fi": node.fi if isinstance(node, Basin) else 0, + "print": prnt, + "excess": 0, + "comment": 0, } - self._idMap[data['id']] = next(self._nodeID) + self._idMap[data["id"]] = next(self._nodeID) self._nodeVector.append(data) def _reachDisplay(self, code: tuple, traveller: Traveller) -> None: @@ -457,27 +513,27 @@ def _reachDisplay(self, code: tuple, traveller: Traveller) -> None: # Order according to the column order in the control vector. 
data = { - 'id': f"<{reach.name}>", - 'name': f" {reach.name}", - 'us': f"<{traveller.getNode(pos).name}>", - 'ds': f"<{traveller.getNode(traveller.down(pos)).name}>", - 'translation': 0, - 'type': reach.type.value, - 'print': 0, - 'length': reach.length() / 1000, - 'slope': reach.slope, - 'npoints': 1, - 'comment': 0, - 'x': x, - 'y': y, + "id": f"<{reach.name}>", + "name": f" {reach.name}", + "us": f"<{traveller.getNode(pos).name}>", + "ds": f"<{traveller.getNode(traveller.down(pos)).name}>", + "translation": 0, + "type": reach.reachType.value, + "print": 0, + "length": reach.length / 1000, + "slope": reach.slope, + "npoints": 1, + "comment": 0, + "x": x, + "y": y, } - self._idMap[data['id']] = next(self._reachID) + self._idMap[data["id"]] = next(self._reachID) self._reachVector.append(data) except KeyError: pass - + @staticmethod def _idGenerator(): """ @@ -489,21 +545,22 @@ def _idGenerator(): i += 1 yield i + class RORB(Model): """ - Create a RORB GE control vector for input to the RORB runoff routing model. + Create a RORB GE control vector for input to the RORB runoff routing model. 
""" - + def __init__(self): pass - + def getVector(self, traveller: Traveller) -> str: traveller.next() vectorBlock = VectorBlock() graphicBlock = GraphicsBlock() - while(traveller._pos != traveller._endSentinel): + while traveller._pos != traveller._endSentinel: vectorBlock.step(traveller) graphicBlock.step(vectorBlock.state[-1], traveller) - return graphicBlock.build() + vectorBlock.build(traveller) \ No newline at end of file + return graphicBlock.build() + vectorBlock.build(traveller) diff --git a/src/pyromb/resources/expected_fields.json b/src/pyromb/resources/expected_fields.json index f5d4e20..8bbabc7 100644 --- a/src/pyromb/resources/expected_fields.json +++ b/src/pyromb/resources/expected_fields.json @@ -2,36 +2,36 @@ "reaches": [ { "name": "t", - "type": "N" + "type": "Integer64" }, { "name": "s", - "type": "N" + "type": "Real" }, { "name": "id", - "type": "C" + "type": "String" } ], "basins": [], "centroids": [ { "name": "id", - "type": "C" + "type": "String" }, { "name": "fi", - "type": "N" + "type": "Real" } ], "confluences": [ { "name": "id", - "type": "C" + "type": "String" }, { "name": "out", - "type": "N" + "type": "Integer64" } ] } \ No newline at end of file diff --git a/src/serialise.py b/src/serialise.py index 4e362e3..b6001ed 100644 --- a/src/serialise.py +++ b/src/serialise.py @@ -11,6 +11,9 @@ # Set the default suffix suffix_item = "_sample_new.json" +# this script was used to check outputs when I broke something +# Might be able to use it to help with unit tests, or pickle things instead + def serialize_to_json(data, filename: str, suffix: str = suffix_item) -> None: """ diff --git a/src/sf_vector_layer.py b/src/sf_vector_layer.py new file mode 100644 index 0000000..66b1090 --- /dev/null +++ b/src/sf_vector_layer.py @@ -0,0 +1,135 @@ +# src/sf_vector_layer.py + +import pyromb +from osgeo import ogr + + +class SFVectorLayer(pyromb.VectorLayer): + """ + Wrap the OGR layer with the necessary interface to work with the Builder. 
+ """ + + def __init__(self, path: str) -> None: + """ + Initialize the SFVectorLayer with the given shapefile path. + + Parameters + ---------- + path : str + The path to the shapefile. + """ + self.path = path + self.driver = ogr.GetDriverByName("ESRI Shapefile") + self.datasource = self.driver.Open(path, 0) # 0 means read-only + if self.datasource is None: + raise FileNotFoundError(f"Could not open shapefile: {path}") + self.layer = self.datasource.GetLayer() + + def geometry(self, i: int) -> list: + """ + Retrieve the geometry points for the ith feature. + + Parameters + ---------- + i : int + The index of the feature. + + Returns + ------- + list + A list of (x, y) tuples representing the geometry. + """ + feature = self.layer.GetFeature(i) + if feature is None: + raise IndexError(f"Feature {i} not found in shapefile.") + geom = feature.GetGeometryRef() + if geom is None: + raise ValueError(f"Feature {i} has no geometry.") + + geom_type = geom.GetGeometryType() + points = [] + + if geom_type == ogr.wkbPoint: + points.append((geom.GetX(), geom.GetY())) + elif geom_type in [ogr.wkbLineString, ogr.wkbMultiLineString]: + points = geom.GetPoints() + elif geom_type in [ogr.wkbPolygon, ogr.wkbMultiPolygon]: + # Extract points from the exterior ring + ring = geom.GetGeometryRef(0) + if ring: + points = ring.GetPoints() + else: + raise ValueError(f"Polygon geometry at feature {i} has no exterior ring.") + else: + raise ValueError(f"Unsupported geometry type: {geom_type} at feature {i}") + + return points + + def record(self, i: int) -> dict: + """ + Retrieve the attributes for the ith feature. + + Parameters + ---------- + i : int + The index of the feature. + + Returns + ------- + dict + A dictionary of attribute names and their corresponding values. 
+ """ + feature = self.layer.GetFeature(i) + if feature is None: + raise IndexError(f"Feature {i} not found in shapefile.") + + field_count = feature.GetFieldCount() + fields = [feature.GetFieldDefnRef(j).GetName() for j in range(field_count)] + values = [feature.GetField(j) for j in range(field_count)] + + return dict(zip(fields, values)) + + def __len__(self) -> int: + """ + Get the number of features in the shapefile. + + Returns + ------- + int + The total number of features. + """ + return self.layer.GetFeatureCount() + + def __del__(self): + """ + Destructor to clean up the OGR datasource. + """ + if self.datasource: + self.datasource.Release() + + def get_fields(self) -> list[tuple[str, int]]: + """ + Return field names and types for the vector layer. + + Returns + ------- + List[Tuple[str, int]] + A list of tuples containing field names and their OGR field type codes. + """ + fields = [] + layer_defn = self.layer.GetLayerDefn() + for i in range(layer_defn.GetFieldCount()): + field_defn = layer_defn.GetFieldDefn(i) + field_name = field_defn.GetName() + field_type_code = field_defn.GetType() + fields.append((field_name, field_type_code)) + return fields + + def get_ogr_geometry(self, i: int) -> ogr.Geometry: + feature = self.layer.GetFeature(i) + if feature is None: + raise IndexError(f"Feature {i} not found in layer.") + geometry = feature.GetGeometryRef() + if geometry is None: + raise ValueError(f"No geometry found for feature {i}.") + return geometry.Clone() diff --git a/vector.catg b/vector.catg deleted file mode 100644 index 764986d..0000000 --- a/vector.catg +++ /dev/null @@ -1,281 +0,0 @@ -REACH -C RORB_GE 6.45 -C WARNING - DO NOT EDIT THIS FILE OUTSIDE RORB TO ENSURE BOTH GRAPHICAL AND CATCHMENT DATA ARE COMPATIBLE WITH EACH OTHER -C THIS FILE CANNOT BE OPENED IN EARLIER VERSIONS OF RORB GE - CURRENT VERSION IS v6.45 -C -C REACH -C -C #FILE COMMENTS -C 0 -C -C #SUB-AREA AREA COMMENTS -C 0 -C -C #IMPERVIOUS FRACTION COMMENTS -C 0 -C -C #BACKGROUND 
IMAGE -C T F -C -C #NODES -C 34 -C 1 2.500 31.335 1.000 1 0 34 A 0.154586 0.100000 0 0 0 -C -C 2 13.116 49.997 1.000 1 0 3 C 0.168503 0.100000 0 0 0 -C -C 3 7.507 36.151 1.000 1 0 33 B 0.150142 0.100000 0 0 0 -C -C 4 44.649 57.640 1.000 1 0 20 F 0.326750 0.100000 0 0 0 -C -C 5 56.285 60.612 1.000 1 0 19 G 0.254624 0.100000 0 0 0 -C -C 6 67.043 44.993 1.000 1 0 16 M 0.344693 0.100000 0 0 0 -C -C 7 79.201 92.500 1.000 1 0 9 I 0.525479 0.100000 0 0 0 -C -C 8 92.500 79.214 1.000 1 0 9 J 0.931558 0.100000 0 0 0 -C -C 9 71.774 78.340 1.000 0 0 10 11 0.000000 0.000000 0 0 0 -C -C 10 68.805 70.548 1.000 1 0 15 H 0.537064 0.100000 0 0 0 -C -C 11 80.835 37.700 1.000 1 0 13 O 0.365812 0.100000 0 0 0 -C -C 12 87.945 50.095 1.000 1 0 13 P 0.965308 0.100000 0 0 0 -C -C 13 72.294 54.930 1.000 0 0 14 12 0.000000 0.000000 0 0 0 -C -C 14 69.068 56.558 1.000 1 0 15 N 0.130634 0.100000 0 0 0 -C -C 15 63.033 61.593 1.000 0 0 16 10 0.000000 0.000000 0 0 0 -C -C 16 59.851 56.619 1.000 0 0 19 9 0.000000 0.000000 0 0 0 -C -C 17 72.777 26.132 1.000 1 0 18 L 0.500843 0.100000 0 0 0 -C -C 18 58.318 44.364 1.000 1 0 19 K 0.162697 0.100000 0 0 0 -C -C 19 53.180 56.498 1.000 0 0 20 8 0.000000 0.000000 0 0 0 -C -C 20 41.067 61.232 1.000 0 0 21 7 0.000000 0.000000 0 0 0 -C -C 21 24.345 55.322 1.000 1 0 32 E 0.266448 0.100000 0 0 0 -C -C 22 68.056 2.500 1.000 1 0 24 T 0.395320 0.100000 0 0 0 -C -C 23 73.480 12.984 1.000 1 0 24 U 0.300901 0.100000 0 0 0 -C -C 24 60.617 18.114 1.000 0 0 25 6 0.000000 0.000000 0 0 0 -C -C 25 55.616 23.754 1.000 1 0 27 S 0.806644 0.100000 0 0 0 -C -C 26 45.118 33.106 1.000 1 0 27 V 0.167300 0.100000 0 0 0 -C -C 27 42.098 45.101 1.000 0 0 29 5 0.000000 0.000000 0 0 0 -C -C 28 36.590 41.486 1.000 1 0 29 R 0.331956 0.100000 0 0 0 -C -C 29 32.794 44.438 1.000 0 0 31 4 0.000000 0.000000 0 0 0 -C -C 30 31.375 49.577 1.000 1 0 31 Q 0.264499 0.100000 0 0 0 -C -C 31 18.310 41.815 1.000 0 0 32 3 0.000000 0.000000 0 0 0 -C -C 32 13.811 38.588 1.000 1 0 33 D 0.228319 0.100000 0 0 
0 -C -C 33 8.764 24.870 1.000 0 0 34 2 0.000000 0.000000 0 0 0 -C -C 34 10.083 23.259 1.000 0 1 7 1 0.000000 0.000000 70 0 0 -C -C -C #REACHES -C 33 -C 1 A.2 1 34 0 1 0 0.579 0.000 1 0 -C 6.292 -C 27.297 -C 2 C.B 2 3 0 1 0 0.544 0.000 1 0 -C 10.312 -C 43.074 -C 3 B.2 3 33 0 1 0 0.495 0.000 1 0 -C 8.136 -C 30.510 -C 4 F.7 4 20 0 1 0 0.215 0.000 1 0 -C 42.858 -C 59.436 -C 5 G.10 5 19 0 1 0 0.202 0.000 1 0 -C 54.732 -C 58.555 -C 6 M.9 6 16 0 1 0 0.549 0.000 1 0 -C 63.447 -C 50.806 -C 7 I.11 7 9 0 1 0 0.691 0.000 1 0 -C 75.488 -C 85.420 -C 8 J.11 8 9 0 1 0 1.012 0.000 1 0 -C 82.137 -C 78.777 -C 9 11.H 9 10 0 1 0 0.321 0.000 1 0 -C 70.289 -C 74.444 -C 10 H.10 10 15 0 1 0 0.448 0.000 1 0 -C 65.919 -C 66.071 -C 11 O.12 11 13 0 1 0 0.767 0.000 1 0 -C 76.564 -C 46.315 -C 12 P.12 12 13 0 1 0 0.778 0.000 1 0 -C 80.119 -C 52.513 -C 13 12.N 13 14 0 1 0 0.162 0.000 1 0 -C 70.681 -C 55.744 -C 14 N.10 14 15 0 1 0 0.354 0.000 1 0 -C 66.051 -C 59.076 -C 15 10.9 15 16 0 1 0 0.251 0.000 1 0 -C 61.442 -C 59.106 -C 16 9.8 16 19 0 1 0 0.330 0.000 1 0 -C 56.516 -C 56.558 -C 17 L.K 17 18 0 1 0 0.964 0.000 1 0 -C 65.547 -C 35.241 -C 18 K.8 18 19 0 1 0 0.473 0.000 1 0 -C 55.749 -C 50.431 -C 19 8.7 19 20 0 1 0 0.707 0.000 1 0 -C 47.124 -C 58.865 -C 20 7.E 20 21 0 1 0 0.928 0.000 1 0 -C 32.706 -C 58.277 -C 21 E.D 21 32 0 1 0 0.750 0.000 1 0 -C 19.078 -C 46.955 -C 22 T.6 22 24 0 1 0 0.685 0.000 1 0 -C 64.337 -C 10.307 -C 23 U.6 23 24 0 1 0 0.652 0.000 1 0 -C 67.049 -C 15.549 -C 24 6.S 24 25 0 1 0 0.309 0.000 1 0 -C 58.117 -C 20.934 -C 25 S.5 25 27 0 1 0 1.112 0.000 1 0 -C 48.856 -C 34.428 -C 26 V.5 26 27 0 1 0 0.550 0.000 1 0 -C 43.608 -C 39.104 -C 27 5.4 27 29 0 1 0 0.569 0.000 1 0 -C 37.446 -C 44.769 -C 28 R.4 28 29 0 1 0 0.234 0.000 1 0 -C 34.692 -C 42.962 -C 29 4.3 29 31 0 1 0 0.781 0.000 1 0 -C 25.552 -C 43.126 -C 30 Q.3 30 31 0 1 0 0.761 0.000 1 0 -C 24.842 -C 45.696 -C 31 3.D 31 32 0 1 0 0.251 0.000 1 0 -C 16.061 -C 40.201 -C 32 D.2 32 33 0 1 0 0.532 0.000 1 0 -C 11.288 -C 31.729 -C 33 
2.1 33 34 0 1 0 0.083 0.000 1 0 -C 9.424 -C 24.064 -C -C #STORAGES -C 0 -C -C #INFLOW/OUTFLOW -C 0 -C -C END RORB_GE -C -0 -1,1,0.579,-99 -3 -1,1,0.544,-99 -2,1,0.495,-99 -3 -1,1,0.215,-99 -3 -1,1,0.202,-99 -3 -1,1,0.549,-99 -3 -1,1,0.691,-99 -3 -1,1,1.012,-99 -4 -5,1,0.321,-99 -2,1,0.448,-99 -3 -1,1,0.767,-99 -3 -1,1,0.778,-99 -4 -5,1,0.162,-99 -2,1,0.354,-99 -4 -5,1,0.251,-99 -4 -5,1,0.330,-99 -4 -3 -1,1,0.964,-99 -2,1,0.473,-99 -4 -5,1,0.707,-99 -4 -5,1,0.928,-99 -2,1,0.750,-99 -3 -1,1,0.685,-99 -3 -1,1,0.652,-99 -4 -5,1,0.309,-99 -2,1,1.112,-99 -3 -1,1,0.550,-99 -4 -5,1,0.569,-99 -3 -1,1,0.234,-99 -4 -5,1,0.781,-99 -3 -1,1,0.761,-99 -4 -5,1,0.251,-99 -4 -2,1,0.532,-99 -4 -5,1,0.083,-99 -4 -7 - -0 -C Sub Area Data -C Areas, km**2, of subareas A,B... - 0.15459, 0.16850, 0.15014, 0.32675, 0.25462, - 0.34469, 0.52548, 0.93156, 0.53706, 0.36581, - 0.96531, 0.13063, 0.50084, 0.16270, 0.26645, - 0.39532, 0.30090, 0.80664, 0.16730, 0.33196, - 0.26450, 0.22832, --99 -C Impervious Fraction Data - 1 , - 0.10000, 0.10000, 0.10000, 0.10000, 0.10000, - 0.10000, 0.10000, 0.10000, 0.10000, 0.10000, - 0.10000, 0.10000, 0.10000, 0.10000, 0.10000, - 0.10000, 0.10000, 0.10000, 0.10000, 0.10000, - 0.10000, 0.10000, - -99 From c547d9e81b18e26e21cf4a073b1e1a065d7a4dcb Mon Sep 17 00:00:00 2001 From: Chain Frost Date: Sun, 20 Oct 2024 20:23:24 +0800 Subject: [PATCH 5/6] clean up imports --- src/app_testing.py | 1 - src/pyromb/core/attributes/reach.py | 2 +- src/pyromb/core/catchment.py | 1 - src/pyromb/core/geometry/line.py | 1 - src/pyromb/core/geometry/polygon.py | 3 +-- src/pyromb/core/geometry/shapefile_validation.py | 3 --- src/pyromb/core/gis/builder.py | 1 - src/pyromb/core/gis/vector_layer.py | 1 - src/pyromb/math/geometry.py | 2 +- src/serialise.py | 1 - 10 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/app_testing.py b/src/app_testing.py index b5f5c5a..6546e02 100644 --- a/src/app_testing.py +++ b/src/app_testing.py @@ -1,7 +1,6 @@ # app_testing.py import os 
import pyromb -from plot_catchment import plot_catchment import logging from sf_vector_layer import SFVectorLayer from app import main diff --git a/src/pyromb/core/attributes/reach.py b/src/pyromb/core/attributes/reach.py index 2a79d15..d2fde51 100644 --- a/src/pyromb/core/attributes/reach.py +++ b/src/pyromb/core/attributes/reach.py @@ -1,5 +1,5 @@ # src/pyromb/core/attributes/reach.py -from typing import Optional, Union, cast +from typing import Optional, cast from ..geometry.line import Line from enum import Enum from .node import Node diff --git a/src/pyromb/core/catchment.py b/src/pyromb/core/catchment.py index f155e83..1f0363e 100644 --- a/src/pyromb/core/catchment.py +++ b/src/pyromb/core/catchment.py @@ -8,7 +8,6 @@ from .attributes.confluence import Confluence from .attributes.node import Node from .attributes.reach import Reach -from ..math.geometry import length import math # Configure logging (adjust as needed or configure in a higher-level module) diff --git a/src/pyromb/core/geometry/line.py b/src/pyromb/core/geometry/line.py index 780eb5f..fd43060 100644 --- a/src/pyromb/core/geometry/line.py +++ b/src/pyromb/core/geometry/line.py @@ -1,6 +1,5 @@ # src/pyromb/core/geometry/line.py from typing import Optional, Iterator -from osgeo import ogr from ...math import geometry from .point import Point diff --git a/src/pyromb/core/geometry/polygon.py b/src/pyromb/core/geometry/polygon.py index 45b6466..94b4f44 100644 --- a/src/pyromb/core/geometry/polygon.py +++ b/src/pyromb/core/geometry/polygon.py @@ -1,6 +1,5 @@ # src/pyromb/core/geometry/polygon.py -from typing import List, Optional -from osgeo import ogr +from typing import Optional from ...math import geometry from .line import Line from .point import Point diff --git a/src/pyromb/core/geometry/shapefile_validation.py b/src/pyromb/core/geometry/shapefile_validation.py index 281e117..09e0f5a 100644 --- a/src/pyromb/core/geometry/shapefile_validation.py +++ 
b/src/pyromb/core/geometry/shapefile_validation.py @@ -2,9 +2,6 @@ import logging from osgeo import ogr # type: ignore from ..gis.vector_layer import VectorLayer -from ...math import geometry - -from osgeo import ogr def validate_shapefile_geometries(vector_layer: VectorLayer, layer_type: str) -> bool: diff --git a/src/pyromb/core/gis/builder.py b/src/pyromb/core/gis/builder.py index 9fa2a6c..3b39268 100644 --- a/src/pyromb/core/gis/builder.py +++ b/src/pyromb/core/gis/builder.py @@ -3,7 +3,6 @@ import json import logging from typing import Optional -from osgeo import ogr # type:ignore import importlib.resources from ..attributes.basin import Basin diff --git a/src/pyromb/core/gis/vector_layer.py b/src/pyromb/core/gis/vector_layer.py index 0cb21e5..0533ada 100644 --- a/src/pyromb/core/gis/vector_layer.py +++ b/src/pyromb/core/gis/vector_layer.py @@ -1,7 +1,6 @@ # src\pyromb\core\gis\vector_layer.py import abc from typing import Any -from osgeo import ogr class VectorLayer(abc.ABC): diff --git a/src/pyromb/math/geometry.py b/src/pyromb/math/geometry.py index 124cdf0..cca2d85 100644 --- a/src/pyromb/math/geometry.py +++ b/src/pyromb/math/geometry.py @@ -2,7 +2,7 @@ import math from typing import Union from osgeo import ogr -from ..core.geometry.point import Point # Ensure correct import path +from ..core.geometry.point import Point import logging diff --git a/src/serialise.py b/src/serialise.py index b6001ed..aa13e5c 100644 --- a/src/serialise.py +++ b/src/serialise.py @@ -2,7 +2,6 @@ import json import csv import logging -import pyromb from pyromb.core.attributes.basin import Basin from pyromb.core.attributes.confluence import Confluence from pyromb.core.attributes.reach import Reach From 47b8cb2cc4b79cfff5a583db5a6a730c3061232b Mon Sep 17 00:00:00 2001 From: Chain-Frost Date: Sun, 5 Oct 2025 22:42:28 +0800 Subject: [PATCH 6/6] docs: document dependencies and add agent guidance (#1) --- AGENTS.md | 11 +++++ documentation/explainers/ai_playbook.md | 28 +++++++++++++ 
documentation/explainers/change_review.md | 22 ++++++++++ documentation/explainers/overview.md | 25 +++++++++++ .../explainers/qgis_integration_plan.md | 42 +++++++++++++++++++ 5 files changed, 128 insertions(+) create mode 100644 AGENTS.md create mode 100644 documentation/explainers/ai_playbook.md create mode 100644 documentation/explainers/change_review.md create mode 100644 documentation/explainers/overview.md create mode 100644 documentation/explainers/qgis_integration_plan.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..3cbcedf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,11 @@ +# Agent Guidelines for Pyromb + +Welcome! When updating this repository, please review the explainer documents before making large changes: + +* `documentation/explainers/overview.md` – architecture and data flow summary. +* `documentation/explainers/qgis_integration_plan.md` – QGIS Processing integration scope and dependency expectations. +* `documentation/explainers/ai_playbook.md` – contribution workflow and testing guardrails for AI assistants. + +These explainers provide the context expected by the maintainers. Update them whenever the architecture or integration plan evolves. + +When adding new explainer-style documentation, keep it in `documentation/explainers/` unless maintainers request another location. diff --git a/documentation/explainers/ai_playbook.md b/documentation/explainers/ai_playbook.md new file mode 100644 index 0000000..410e0a2 --- /dev/null +++ b/documentation/explainers/ai_playbook.md @@ -0,0 +1,28 @@ +# AI Contributor Playbook + +Use this playbook when applying AI-assisted development to Pyromb. It distils the key conventions, guardrails, and exploratory steps that have helped previous iterations succeed. + +## Before writing code + +1. **Understand the target module** – Read the relevant files under `src/pyromb` (especially the `core` package) before proposing changes. 
The architecture overview in this folder explains the data flow and should be revisited when planning large updates. +2. **Confirm geometry expectations** – All geometry work happens through the abstractions in `core/geometry`. Avoid introducing new geometry libraries; reuse the existing point/line/polygon classes or extend them cautiously. +3. **Check dependency rules** – QGIS deployments expect only standard library, NumPy, and GDAL/OGR dependencies. If a feature appears to require Shapely or other heavy GIS libraries, reconsider the approach or design a fallback that uses native QGIS geometry APIs. +4. **Plan for serialisation** – Any new attributes that affect the output vector files must propagate through `Builder`, `Catchment`, and the relevant `model` writer. Sketch the end-to-end data path before coding to avoid partial integrations. + +## During implementation + +* **Reuse builders and validators** – Keep shapefile parsing inside the builder. If you need additional validation rules, add them to `core/geometry/shapefile_validation.py` so they apply consistently across entry points. +* **Maintain tolerance handling** – When matching coordinates, continue to use tolerance-based comparisons. If more precision is necessary, introduce configurable parameters rather than hard-coded thresholds. +* **Instrument with logging** – Use the `logging` module for standalone scripts and prepare to swap in QGIS `feedback` hooks when integrating into Processing. Avoid print statements. + +## Testing strategy + +1. **Unit coverage** – Prioritise pure Python units (geometry, maths, model writers) because they run reliably in CI and do not require QGIS. +2. **Integration checks** – For geometry or builder changes, run the helper script (`python -m src.app`) against the sample data in the `data` directory to confirm end-to-end output remains valid. +3. 
**Regression artefacts** – Use the serialisation helpers in `serialise.py` to capture JSON or CSV snapshots of intermediate structures when diagnosing behaviour differences between QGIS and standalone runs. +
+## Follow-up questions to resolve
+
+* Do we need a dedicated fixture generation script to refresh sample shapefiles for new test cases?
+* Should tolerance thresholds become part of a configuration file so they can be tuned per project?
+* Would a lightweight CLI wrapper around the builder be useful for non-QGIS environments once the Processing integration is complete? diff --git a/documentation/explainers/change_review.md b/documentation/explainers/change_review.md new file mode 100644 index 0000000..0cd7ced --- /dev/null +++ b/documentation/explainers/change_review.md @@ -0,0 +1,22 @@ +# Review: docs commit vs previous fork state +
+This note compares commit `8725355` (current `work` branch head) against the prior fork state at `c547d9e`. +
+## Summary of differences +
+* Adds three explainer documents (`overview.md`, `qgis_integration_plan.md`, `ai_playbook.md`) under `documentation/explainers/`. +* No Python package code, tests, or packaging files changed between the two commits. +
+## Accuracy and completeness observations +
+* The architecture overview correctly names the core entry points (`src/app.py`, `pyromb.Builder`, `Catchment`, `Traveller`) present in the repository. +* The QGIS integration plan notes the absence of Shapely and aligns with the dependency list declared in `pyproject.toml`. +* The AI playbook references testing helpers and pending questions consistent with repository structure. +
+## Suggested follow-ups +
+1. Keep the new explainers updated alongside future code changes so they remain trustworthy context for contributors. +2. Ensure future Processing-algorithm tasks align with the dependency guidance—avoid adding non-native libraries unless QGIS bundles them. +3. 
Consider expanding the AI playbook with explicit guidelines for geometry tolerance changes once the QGIS integration amendments begin. + +No regressions were detected in this comparison because only documentation files were added. diff --git a/documentation/explainers/overview.md b/documentation/explainers/overview.md new file mode 100644 index 0000000..558bdeb --- /dev/null +++ b/documentation/explainers/overview.md @@ -0,0 +1,25 @@ +# Pyromb Architecture Overview + +This explainer summarises how the Pyromb library organises its logic so that AI-assisted contributors can quickly locate the right components. + +## High-level package structure + +Pyromb exposes a builder-style API for turning GIS vector layers into hydrologic model control files. The key entry point used by the command-line helper script is `src/app.py`, which orchestrates shapefile ingestion, catchment assembly, and file serialisation. It instantiates vector layer wrappers, builds model components with `pyromb.Builder`, connects them into a `Catchment`, then produces the RORB or WBNM text output via a `Traveller`. The helper also supports optional plotting and serialisation for regression testing. `src/app_testing.py` mirrors this workflow with different defaults for test automation, while `serialise.py` and `plot_catchment.py` contain supporting utilities for debugging and visualisation. + +Inside the Python package (`src/pyromb`), functionality is grouped into four domains: + +* `core`: domain models and algorithms for catchment construction, traversal, and validation. This includes attribute classes (basins, reaches, confluences), geometry abstractions that wrap raw coordinates, and the `Catchment` and `Traveller` classes that connect everything into incidence matrices and model-specific exports. +* `math`: hydrologic calculations (e.g. loss models) that are used during serialisation of the output control files. 
+* `model`: format-specific builders for RORB and WBNM control vectors, and serialisation helpers that convert the in-memory catchment into strings written by the traveller. +* `resources`: default templates and static assets used by the model writers. + +Supporting modules such as `sf_vector_layer.py` expose a minimal wrapper around OGR vector layers so the builder can consume shapefile geometry consistently. + +## Data flow from GIS layers to control files + +1. **Load vector data** – Shapefiles for reaches, basins, centroids, and confluences are opened via `SFVectorLayer`, which exposes iterable records with geometry and attributes. +2. **Build domain objects** – `pyromb.Builder` reads the vector layers, instantiates geometry primitives (`core.geometry`) and attribute classes, and returns lists of reaches, basins, and confluences. +3. **Assemble the catchment** – A `Catchment` is created from those components. It deduplicates vertices, computes distance-based incidence matrices, and determines upstream/downstream relationships used later by the traveller. +4. **Generate model output** – A `Traveller` walks the connected catchment to populate model-specific writers (`model` package). For RORB this produces a `vector.catg`; for WBNM a `runfile.wbnm`. Optional plotting uses the connected structure to render diagnostic figures. + +Understanding this flow helps target AI-driven changes: geometry or validation fixes generally belong in `core.geometry` or the builder; file-format changes live in `model`; and QGIS integration work happens in the app wrapper and layer adapters. 
diff --git a/documentation/explainers/qgis_integration_plan.md b/documentation/explainers/qgis_integration_plan.md new file mode 100644 index 0000000..129b059 --- /dev/null +++ b/documentation/explainers/qgis_integration_plan.md @@ -0,0 +1,42 @@ +# QGIS Integration and Dependency Plan + +This note captures the changes required to embed Pyromb as a native QGIS Processing algorithm without introducing non-standard dependencies. + +## Dependency expectations inside QGIS + +* **Python environment** – Recent QGIS releases ship with Python 3.9+ and include GDAL/OGR, PyQt, NumPy, and QGIS core bindings. Third-party packages such as Shapely are not guaranteed to be present and should be avoided unless bundled manually. +* **Current state** – The repository already limits itself to standard library modules, NumPy, and GDAL/OGR via `osgeo`. A quick audit (`rg "shapely"`) shows no references to Shapely in the current codebase, so future contributions should preserve this dependency footprint. + +## Python package installation footprint + +Pyromb distributes as a pure Python package with a minimal dependency tree: + +* **`gdal`** – Declared in `pyproject.toml` so that the GDAL/OGR Python bindings are available for shapefile access when running outside QGIS. Inside QGIS, the bundled GDAL satisfies this requirement. +* **Standard library / QGIS built-ins** – The remaining imports come from Python's standard library, NumPy, and QGIS' own modules that ship with the application. No additional PyPI packages are installed by default. + +When packaging for QGIS Processing, avoid adding new third-party dependencies unless they are guaranteed to be shipped with QGIS or vendored with the plugin. + +## Embedding Pyromb as a Processing provider + +1. **Processing entry point** – Wrap the logic in `src/app.py` inside a `QgsProcessingAlgorithm` subclass. 
The subclass should wrap the logic from `src/app.py`, and its `processAlgorithm` method should: + * Accept parameter definitions for the required vector layers (reaches, basins, centroids, confluences) and optional toggles (plotting, output path). + * Use `QgsProcessingParameterFeatureSource` to read vector layers directly from the QGIS context, avoiding temporary files. + * Instantiate `pyromb.Builder`, `Catchment`, and `Traveller` the same way the helper script does. +2. **Vector layer abstraction** – Replace `SFVectorLayer` usage with adapters that can consume `QgsFeatureSource` instances. The goal is to reuse existing validation and geometry code, so consider creating an interface that both the shapefile wrapper and a new QGIS wrapper implement. +3. **Output handling** – The algorithm should write to a processing output (`QgsProcessingOutputFile`) and optionally return the generated vector file path to the model catalog. +4. **Logging and feedback** – Use the `feedback` object provided by QGIS Processing for progress messages instead of the standard `logging` module when running inside QGIS. Retain standard logging for standalone runs by introducing a thin abstraction or conditional helper. +
+## Geometry logic amendments +
+The original plugin required adjustments to geometry handling to work reliably inside QGIS. Further work should focus on: +
+* **Tolerance-aware matching** – `Catchment._find_vertex_by_coordinates` currently uses a numeric tolerance when matching nodes. Confirm that the tolerance is appropriate for the coordinate precision in QGIS projects and expose it as a configurable parameter if needed. +* **Projected vs geographic CRS** – Ensure builders operate on projected coordinates (metres) when computing lengths, slopes, or areas. Incorporate CRS validation in the Processing algorithm to warn users when they run the tool in a geographic CRS. +* **Geometry extraction** – When moving from OGR-based layers to `QgsFeature`, use native geometry accessors (`feature.geometry().asPolyline()` etc.) 
to avoid relying on Shapely conversion helpers. + +## Next steps and open questions + +* Define a minimal adapter interface for vector layers so we can plug in both OGR and QGIS data sources without duplicating builder logic. +* Decide whether plotting should remain part of the Processing algorithm or become a standalone diagnostic tool. Plotting libraries may not be available in all QGIS deployments. +* Review existing serialisation helpers (`serialise.py`) to ensure they are optional in production builds; they are mainly for testing and may not suit a Processing context. +* Document user-facing parameter descriptions and default values for the QGIS tool so plugin development can begin immediately after the geometry updates.