diff --git a/docker-compose.yml b/docker-compose.yml index 6a2cbc85..45312a86 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,8 +20,6 @@ x-flask-defaults: &flask-defaults environment: # Set this variable in .env to start the app with a different config file (default: config.yaml) CONFIG_FILE: - # TODO: This should be removed at some point and the app should be made SQLAlchemy 2.0 compatible! - SQLALCHEMY_SILENCE_UBER_WARNING: 1 OCPDB_POSTGRES_DB: ocpdb OCPDB_POSTGRES_USER: ocpdb OCPDB_POSTGRES_PASSWORD: admin @@ -71,7 +69,7 @@ services: retries: 20 postgre: - image: postgis/postgis + image: postgis/postgis:15-3.5-alpine volumes: - postgres:/var/lib/postgresql/data/ environment: diff --git a/requirements-dev.txt b/requirements-dev.txt index 37cceaa7..984c91ba 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ pytest~=8.3.5 -pytest-cov~=6.0.0 +pytest-cov~=6.1.1 requests-mock~=1.12.1 ruff~=0.11.5 diff --git a/requirements.txt b/requirements.txt index 513facc3..4cf19bcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,11 @@ gunicorn~=23.0.0 Flask~=3.1.0 Flask-Failsafe~=0.2.0 -# Flask-SQLAlchemy 3.1 requires SQLAlchemy 2.0 -Flask-SQLAlchemy~=3.0.5 +Flask-SQLAlchemy~=3.1.1 Flask-Celery-Helper~=1.1.0 -Flask-Migrate~=4.0.7 +Flask-Migrate~=4.1.0 Flask-CORS~=5.0.1 -SQLAlchemy~=1.4.54 +SQLAlchemy~=2.0.40 requests~=2.32.3 alembic~=1.15.2 lxml~=5.3.2 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 61c1fccc..b2686690 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -16,38 +16,51 @@ along with this program. If not, see . """ +import re +from typing import Generator + import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text +from tests.integration.helpers import empty_all_tables from webapp import launch from webapp.common.flask_app import App from webapp.common.sqlalchemy import SQLAlchemy from webapp.extensions import db as flask_sqlalchemy -@pytest.fixture -def app() -> App: - test_app = launch() - test_app.config.update( - TESTING=True, - DEBUG=True, +@pytest.fixture(scope='session') +def flask_app() -> Generator[App, None, None]: + app = launch( + config_overrides={ + 'TESTING': True, + 'DEBUG': True, + } ) - with test_app.app_context(): - yield test_app - -@pytest.fixture -def db(app: App) -> SQLAlchemy: # Create the database and the database tables - # db_path should be 'mysql+pymysql://root:root@mysql' if - # SQLALCHEMY_DATABASE_URI: 'mysql+pymysql://root:root@mysql/ocpdb' is set in test_config.yaml - db_path: str = app.config.get('SQLALCHEMY_DATABASE_URI')[:-6] + # SQLALCHEMY_DATABASE_URI: 'mysql+pymysql://root:root@mysql/backend?charset=utf8mb4' is set in test_config.yaml + db_path: str = re.sub(r'/[^/]+$', '', app.config.get('SQLALCHEMY_DATABASE_URI')) engine = create_engine(db_path) - connection = engine.connect() - connection.execute('DROP DATABASE IF EXISTS ocpdb;') - connection.execute('CREATE DATABASE IF NOT EXISTS ocpdb;') - flask_sqlalchemy.create_all() + + # We use DROP + CREATE here because it's faster and more reliable in case of foreign keys + with engine.connect() as connection: + connection.execute(text('DROP DATABASE IF EXISTS `post-salad-backend`;')) + connection.execute(text('CREATE DATABASE IF NOT EXISTS `post-salad-backend`;')) + + with app.app_context(): + flask_sqlalchemy.create_all() + + yield app # type: ignore + + +@pytest.fixture +def db(flask_app: App) -> Generator[SQLAlchemy, None, None]: + """ + Yields the database as a function-scoped fixture with freshly emptied tables. + """ + empty_all_tables(db=flask_sqlalchemy) yield flask_sqlalchemy diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py new file mode 100644 index 00000000..1ee6afc8 --- /dev/null +++ b/tests/integration/helpers.py @@ -0,0 +1,34 @@ +""" +Open ChargePoint DataBase OCPDB +Copyright (C) 2025 binary butterfly GmbH + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . +""" + +from sqlalchemy import text + +from webapp.common.sqlalchemy import SQLAlchemy + + +def empty_all_tables(db: SQLAlchemy) -> None: + """ + empty all tables in the database + (this is much faster than completely deleting the database and creating a new one) + """ + db.session.close() + with db.engine.connect() as connection: + connection.execute(text('SET FOREIGN_KEY_CHECKS=0;')) + for table_name in db.metadata.tables.keys(): + connection.execute(text(f'TRUNCATE `{table_name}`;')) + connection.execute(text('SET FOREIGN_KEY_CHECKS=1;')) diff --git a/webapp/app.py b/webapp/app.py index c8649ca4..e566383f 100644 --- a/webapp/app.py +++ b/webapp/app.py @@ -34,9 +34,9 @@ __all__ = ['launch'] -def launch() -> App: +def launch(config_overrides: dict | None = None) -> App: app = App(BaseConfig.PROJECT_NAME) - configure_app(app) + configure_app(app, config_overrides) configure_extensions(app) configure_blueprints(app) configure_error_handlers(app) @@ -44,9 +44,9 @@ def launch() -> App: return app -def configure_app(app: App) -> None: +def configure_app(app: App, config_overrides: dict | None = None) -> None: config_loader = ConfigLoader() - config_loader.configure_app(app) + config_loader.configure_app(app, config_overrides) def configure_extensions(app: App) -> None: diff --git a/webapp/common/config/config_loader.py b/webapp/common/config/config_loader.py index 56803aef..080ea449 100644 --- a/webapp/common/config/config_loader.py +++ b/webapp/common/config/config_loader.py @@ -27,7 +27,7 @@ class ConfigLoader: @staticmethod - def configure_app(app: Flask) -> None: + def configure_app(app: Flask, config_overrides: None = None) -> None: """ Initializes the app config with default values and loads the actual config from a YAML file. """ @@ -65,6 +65,9 @@ def configure_app(app: Flask) -> None: for key, server in app.config['REMOTE_SERVERS'].items() } + if config_overrides is not None: + app.config.update(config_overrides) + # Ensure that important config values are set config_check = [key for key in app.config['ENFORCE_CONFIG_VALUES'] if key not in app.config] if len(config_check) > 0: diff --git a/webapp/common/sqlalchemy/__init__.py b/webapp/common/sqlalchemy/__init__.py index 197bdd0e..a45be2d5 100644 --- a/webapp/common/sqlalchemy/__init__.py +++ b/webapp/common/sqlalchemy/__init__.py @@ -18,4 +18,3 @@ from .query import Query from .sqlalchemy import SQLAlchemy -from .typing import Mapped diff --git a/webapp/common/sqlalchemy/typing.py b/webapp/common/sqlalchemy/typing.py deleted file mode 100644 index b859f77c..00000000 --- a/webapp/common/sqlalchemy/typing.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Open ChargePoint DataBase OCPDB -Copyright (C) 2021 binary butterfly GmbH - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU Affero General Public License as published by -the Free Software Foundation, either version 3 of the License, or -(at your option) any later version. - -This program is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU Affero General Public License for more details. - -You should have received a copy of the GNU Affero General Public License -along with this program. If not, see . -""" - -from typing import TypeVar, Union - -from sqlalchemy.orm import QueryableAttribute - -__all__ = [ - 'Mapped', -] - -""" -Define type alias for proper type hinting in models. - -Usage: - some_integer: Mapped[int] = db.Column(db.Integer, nullable=False, ...) - nullable_integer: Mapped[Optional[int]] = db.Column(db.Integer, nullable=True, ...) - some_related_things: Mapped[List[Foo]] = db.relationship('Foo') -""" - -T = TypeVar('T') - -# Note: SQLAlchemy actually comes with a class `sqlalchemy.orm.Mapped` which is exactly for the purpose of type hinting. -# Sadly, this is supported either by PyCharm nor by typeshed (as of 2022-11-08), so we define our own type alias here. -Mapped = Union[QueryableAttribute, T] diff --git a/webapp/dependencies.py b/webapp/dependencies.py index 8d3e9f8f..45347bcd 100644 --- a/webapp/dependencies.py +++ b/webapp/dependencies.py @@ -21,7 +21,7 @@ from butterfly_pubsub.sync import PubSubClient from flask import current_app -from sqlalchemy.orm import Session +from sqlalchemy.orm import scoped_session from webapp.common.celery import CeleryHelper from webapp.common.config import ConfigHelper @@ -125,7 +125,7 @@ def get_server_auth_helper(self) -> 'ServerAuthHelper': # Database @cache_dependency - def get_db_session(self) -> Session: + def get_db_session(self) -> scoped_session: # Late import (don't initialize all the extensions unless needed) from webapp.extensions import db diff --git a/webapp/models/base.py b/webapp/models/base.py index d8504f91..ba361b39 100644 --- a/webapp/models/base.py +++ b/webapp/models/base.py @@ -17,31 +17,34 @@ """ from datetime import datetime, timezone -from typing import List, Optional +from sqlalchemy import BigInteger +from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.types import UserDefinedType from sqlalchemy_utc import UtcDateTime -from webapp.common.sqlalchemy import Mapped from webapp.extensions import db -class BaseModel: +class BaseModel(db.Model): + __abstract__ = True __table_args__ = { 'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci', } - id: Mapped[int] = db.Column(db.BigInteger, primary_key=True) - created: Mapped[datetime] = db.Column(UtcDateTime(), nullable=False, default=lambda: datetime.now(tz=timezone.utc)) - modified: Mapped[datetime] = db.Column( + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, nullable=False) + created: Mapped[datetime] = mapped_column( + UtcDateTime(), nullable=False, default=lambda: datetime.now(tz=timezone.utc) + ) + modified: Mapped[datetime] = mapped_column( UtcDateTime(), nullable=False, default=datetime.now(tz=timezone.utc), onupdate=datetime.now(tz=timezone.utc), ) - def to_dict(self, fields: Optional[List[str]] = None, ignore: Optional[List[str]] = None) -> dict: + def to_dict(self, fields: list[str] | None = None, ignore: list[str] | None = None) -> dict: result = {} for field in self.metadata.tables[self.__tablename__].c.keys(): if fields is not None and field not in fields: diff --git a/webapp/models/business.py b/webapp/models/business.py index faf3fc94..19b81653 100644 --- a/webapp/models/business.py +++ b/webapp/models/business.py @@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Optional -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db +from sqlalchemy import BigInteger, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship from .base import BaseModel @@ -27,15 +27,15 @@ from .image import Image -class Business(db.Model, BaseModel): +class Business(BaseModel): __tablename__ = 'business' - logo: Mapped[Optional['Image']] = db.relationship('Image', uselist=False) + logo: Mapped[Optional['Image']] = relationship('Image', uselist=False) - logo_id: Mapped[int | None] = db.Column(db.BigInteger, db.ForeignKey('image.id', use_alter=True), nullable=True) + logo_id: Mapped[int | None] = mapped_column(BigInteger, ForeignKey('image.id', use_alter=True), nullable=True) - name: Mapped[str] = db.Column(db.String(255), index=True, nullable=False) - website: Mapped[str | None] = db.Column(db.String(255), nullable=True) + name: Mapped[str] = mapped_column(String(255), index=True, nullable=False) + website: Mapped[str | None] = mapped_column(String(255), nullable=True) def to_dict(self, *args, ignore: list[str] | None = None, **kwargs) -> dict: ignore = ignore or [] diff --git a/webapp/models/connector.py b/webapp/models/connector.py index 4f0fe9b0..2adf6663 100644 --- a/webapp/models/connector.py +++ b/webapp/models/connector.py @@ -21,11 +21,11 @@ from math import sqrt from typing import TYPE_CHECKING +from sqlalchemy import BigInteger, ForeignKey, Integer, String +from sqlalchemy import Enum as SqlalchemyEnum +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy_utc import UtcDateTime -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db - from .base import BaseModel if TYPE_CHECKING: @@ -106,23 +106,23 @@ class PowerType(Enum): DC = 'DC' -class Connector(db.Model, BaseModel): +class Connector(BaseModel): __tablename__ = 'connector' - evse: Mapped['Evse'] = db.relationship('Evse', back_populates='connectors') - evse_id: Mapped[int] = db.Column(db.BigInteger, db.ForeignKey('evse.id', use_alter=True), nullable=False) + evse: Mapped['Evse'] = relationship('Evse', back_populates='connectors') + evse_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('evse.id', use_alter=True), nullable=False) - uid: Mapped[str] = db.Column(db.String(64), nullable=False, index=True) # OCPI: id - standard: Mapped[ConnectorType | None] = db.Column(db.Enum(ConnectorType), nullable=True) - format: Mapped[ConnectorFormat | None] = db.Column(db.Enum(ConnectorFormat), nullable=True) + uid: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # OCPI: id + standard: Mapped[ConnectorType | None] = mapped_column(SqlalchemyEnum(ConnectorType), nullable=True) + format: Mapped[ConnectorFormat | None] = mapped_column(SqlalchemyEnum(ConnectorFormat), nullable=True) # OCHP: chargePointType, OCPI: power_type - power_type: Mapped[PowerType | None] = db.Column(db.Enum(PowerType), nullable=True) - max_voltage: Mapped[int | None] = db.Column(db.Integer, nullable=True) # OCHP: nominalVoltage, OCPI: max_voltage - max_amperage: Mapped[int | None] = db.Column(db.Integer, nullable=True) # OCPI: max_amperage + power_type: Mapped[PowerType | None] = mapped_column(SqlalchemyEnum(PowerType), nullable=True) + max_voltage: Mapped[int | None] = mapped_column(Integer, nullable=True) # OCHP: nominalVoltage, OCPI: max_voltage + max_amperage: Mapped[int | None] = mapped_column(Integer, nullable=True) # OCPI: max_amperage # OCHP: maximumPower, OCPI: max_electric_power - max_electric_power: Mapped[int | None] = db.Column(db.Integer, nullable=True) - last_updated: Mapped[datetime | None] = db.Column(UtcDateTime(), nullable=True) - terms_and_conditions: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCPI: terms_and_conditions + max_electric_power: Mapped[int | None] = mapped_column(Integer, nullable=True) + last_updated: Mapped[datetime | None] = mapped_column(UtcDateTime(), nullable=True) + terms_and_conditions: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCPI: terms_and_conditions # tariff_ids TODO diff --git a/webapp/models/evse.py b/webapp/models/evse.py index b5be63e9..a7d831fc 100644 --- a/webapp/models/evse.py +++ b/webapp/models/evse.py @@ -22,14 +22,16 @@ from enum import Enum from typing import TYPE_CHECKING +from sqlalchemy import BigInteger, Float, ForeignKey, Integer, Numeric, String, Text +from sqlalchemy import Enum as SqlalchemyEnum from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy_utc import UtcDateTime from webapp.common.json import DefaultJSONEncoder -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db from .base import BaseModel +from .evse_image import EvseImageAssociation if TYPE_CHECKING: from .connector import Connector @@ -81,59 +83,56 @@ class Capability(Enum): DIRECT_REMOTE = 'DIRECT_REMOTE' -evse_image = db.Table( - 'evse_image', - db.Column('evse_id', db.BigInteger, db.ForeignKey('evse.id', use_alter=True), nullable=False), - db.Column('image_id', db.BigInteger, db.ForeignKey('image.id', use_alter=True), nullable=False), -) - - -class Evse(db.Model, BaseModel): +class Evse(BaseModel): __tablename__ = 'evse' - connectors: Mapped[list['Connector']] = db.relationship( + connectors: Mapped[list['Connector']] = relationship( 'Connector', back_populates='evse', cascade='all, delete, delete-orphan', ) - images: Mapped[list['Image']] = db.relationship('Image', secondary=evse_image) - related_resources: Mapped['RelatedResource'] = db.relationship( + images: Mapped[list['Image']] = relationship( + 'Image', + secondary=EvseImageAssociation.__table__, + back_populates='evses', + ) + related_resources: Mapped['RelatedResource'] = relationship( 'RelatedResource', back_populates='evse', cascade='all, delete, delete-orphan', ) - location: Mapped['Location'] = db.relationship('Location', back_populates='evses') + location: Mapped['Location'] = relationship('Location', back_populates='evses') - location_id: Mapped[int] = db.Column(db.BigInteger, db.ForeignKey('location.id', use_alter=True), nullable=False) + location_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('location.id', use_alter=True), nullable=False) - uid: Mapped[str] = db.Column(db.String(64), nullable=False, index=True) - evse_id: Mapped[str | None] = db.Column(db.String(64), nullable=True, index=True) - status: Mapped[EvseStatus] = db.Column( - db.Enum(EvseStatus, name='EvseStatus'), + uid: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + evse_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True) + status: Mapped[EvseStatus] = mapped_column( + SqlalchemyEnum(EvseStatus, name='EvseStatus'), default=EvseStatus.UNKNOWN, nullable=False, ) - lat: Mapped[Decimal | None] = db.Column(db.Numeric(9, 7), nullable=True) - lon: Mapped[Decimal | None] = db.Column(db.Numeric(10, 7), nullable=True) + lat: Mapped[Decimal | None] = mapped_column(Numeric(9, 7), nullable=True) + lon: Mapped[Decimal | None] = mapped_column(Numeric(10, 7), nullable=True) - floor_level: Mapped[str | None] = db.Column(db.String(16), nullable=True) - physical_reference: Mapped[str | None] = db.Column(db.String(255), nullable=True) - _directions: Mapped[str | None] = db.Column('directions', db.Text, nullable=True) - phone: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCHP: telephoneNumber + floor_level: Mapped[str | None] = mapped_column(String(16), nullable=True) + physical_reference: Mapped[str | None] = mapped_column(String(255), nullable=True) + _directions: Mapped[str | None] = mapped_column('directions', Text, nullable=True) + phone: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCHP: telephoneNumber - parking_uid: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCHP: parkingSpot.parkingId - parking_floor_level: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCHP: parkingSpot.floorlevel + parking_uid: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCHP: parkingSpot.parkingId + parking_floor_level: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCHP: parkingSpot.floorlevel # OCHP: parkingSpot.parkingSpotNumber - parking_spot_number: Mapped[str | None] = db.Column(db.String(255), nullable=True) + parking_spot_number: Mapped[str | None] = mapped_column(String(255), nullable=True) - last_updated: Mapped[datetime | None] = db.Column(UtcDateTime(), nullable=True) - max_reservation: Mapped[float | None] = db.Column(db.Float, nullable=True) # OCHP maxReservation - _capabilities: Mapped[int | None] = db.Column('capabilities', db.Integer, nullable=True) # OCPI: capability + last_updated: Mapped[datetime | None] = mapped_column(UtcDateTime(), nullable=True) + max_reservation: Mapped[float | None] = mapped_column(Float, nullable=True) # OCHP maxReservation + _capabilities: Mapped[int | None] = mapped_column('capabilities', Integer, nullable=True) # OCPI: capability # OCHP: RestrictionType OCPI: parking_restrictions - _parking_restrictions: Mapped[int | None] = db.Column('parking_restrictions', db.Integer, nullable=True) + _parking_restrictions: Mapped[int | None] = mapped_column('parking_restrictions', Integer, nullable=True) - terms_and_conditions: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCPI: terms_and_conditions + terms_and_conditions: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCPI: terms_and_conditions # status_schedule TODO # user_interface_lang TODO # OCHP userInterfaceLang diff --git a/webapp/models/evse_image.py b/webapp/models/evse_image.py new file mode 100644 index 00000000..989c4cf2 --- /dev/null +++ b/webapp/models/evse_image.py @@ -0,0 +1,39 @@ +""" +Open ChargePoint DataBase OCPDB +Copyright (C) 2025 binary butterfly GmbH + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . +""" + +from sqlalchemy import BigInteger, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column + +from webapp.extensions import db + + +class EvseImageAssociation(db.Model): + __tablename__ = 'evse_image' + + evse_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey('evse.id'), + primary_key=True, + nullable=False, + ) + image_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey('image.id'), + primary_key=True, + nullable=False, + ) diff --git a/webapp/models/exceptional_closing_period.py b/webapp/models/exceptional_closing_period.py index 49970bf0..abd2535f 100644 --- a/webapp/models/exceptional_closing_period.py +++ b/webapp/models/exceptional_closing_period.py @@ -19,24 +19,24 @@ from datetime import datetime from typing import TYPE_CHECKING +from sqlalchemy import BigInteger, ForeignKey +from sqlalchemy.orm import Mapped, relationship from sqlalchemy_utc import UtcDateTime -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db - from .base import BaseModel +from .connector import mapped_column if TYPE_CHECKING: from .location import Location -class ExceptionalClosingPeriod(db.Model, BaseModel): +class ExceptionalClosingPeriod(BaseModel): __tablename__ = 'exceptional_closing_period' - location: Mapped['Location'] = db.relationship('Location', back_populates='exceptional_closings') - location_id: Mapped[int] = db.Column(db.BigInteger, db.ForeignKey('location.id', use_alter=True), nullable=False) - period_begin: Mapped[datetime] = db.Column(UtcDateTime(), nullable=False) - period_end: Mapped[datetime] = db.Column(UtcDateTime(), nullable=False) + location: Mapped['Location'] = relationship('Location', back_populates='exceptional_closings') + location_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('location.id', use_alter=True), nullable=False) + period_begin: Mapped[datetime] = mapped_column(UtcDateTime(), nullable=False) + period_end: Mapped[datetime] = mapped_column(UtcDateTime(), nullable=False) def to_dict(self, *args, ignore: list[str] | None = None, **kwargs) -> dict: ignore = ignore or [] diff --git a/webapp/models/exceptional_opening_period.py b/webapp/models/exceptional_opening_period.py index ac4301a4..d52bf8e8 100644 --- a/webapp/models/exceptional_opening_period.py +++ b/webapp/models/exceptional_opening_period.py @@ -19,24 +19,24 @@ from datetime import datetime from typing import TYPE_CHECKING +from sqlalchemy import BigInteger, ForeignKey +from sqlalchemy.orm import Mapped, relationship from sqlalchemy_utc import UtcDateTime -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db - from .base import BaseModel +from .connector import mapped_column if TYPE_CHECKING: from .location import Location -class ExceptionalOpeningPeriod(db.Model, BaseModel): +class ExceptionalOpeningPeriod(BaseModel): __tablename__ = 'exceptional_opening_period' - location: Mapped['Location'] = db.relationship('Location', back_populates='exceptional_openings') - location_id: Mapped[int] = db.Column(db.BigInteger, db.ForeignKey('location.id', use_alter=True), nullable=False) - period_begin: Mapped[datetime] = db.Column(UtcDateTime(), nullable=False) - period_end: Mapped[datetime] = db.Column(UtcDateTime(), nullable=False) + location: Mapped['Location'] = relationship('Location', back_populates='exceptional_openings') + location_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('location.id', use_alter=True), nullable=False) + period_begin: Mapped[datetime] = mapped_column(UtcDateTime(), nullable=False) + period_end: Mapped[datetime] = mapped_column(UtcDateTime(), nullable=False) def to_dict(self, *args, ignore: list[str] | None = None, **kwargs) -> dict: ignore = ignore or [] diff --git a/webapp/models/image.py b/webapp/models/image.py index db5ce173..605c6069 100644 --- a/webapp/models/image.py +++ b/webapp/models/image.py @@ -19,14 +19,21 @@ from datetime import datetime from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING from flask import current_app +from sqlalchemy import Enum as SqlalchemyEnum +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy_utc import UtcDateTime -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db - from .base import BaseModel +from .evse_image import EvseImageAssociation +from .location_image import LocationImageAssociation + +if TYPE_CHECKING: + from .evse import Evse + from .location import Location class ImageCategory(Enum): @@ -39,15 +46,26 @@ class ImageCategory(Enum): OWNER = 'OWNER' -class Image(db.Model, BaseModel): +class Image(BaseModel): __tablename__ = 'image' - external_url: Mapped[str | None] = db.Column(db.String(255), index=True, nullable=True) - type: Mapped[str | None] = db.Column(db.String(4), nullable=True) - category: Mapped[ImageCategory | None] = db.Column(db.Enum(ImageCategory), nullable=True) - width: Mapped[int | None] = db.Column(db.Integer, nullable=True) - height: Mapped[int | None] = db.Column(db.Integer, nullable=True) - last_download: Mapped[datetime | None] = db.Column(UtcDateTime(timezone=True), nullable=True) + evses: Mapped[list['Evse']] = relationship( + 'Evse', + secondary=EvseImageAssociation.__table__, + back_populates='images', + ) + locations: Mapped[list['Location']] = relationship( + 'Location', + secondary=LocationImageAssociation.__table__, + back_populates='images', + ) + + external_url: Mapped[str | None] = mapped_column(String(255), index=True, nullable=True) + type: Mapped[str | None] = mapped_column(String(4), nullable=True) + category: Mapped[ImageCategory | None] = mapped_column(SqlalchemyEnum(ImageCategory), nullable=True) + width: Mapped[int | None] = mapped_column(Integer, nullable=True) + height: Mapped[int | None] = mapped_column(Integer, nullable=True) + last_download: Mapped[datetime | None] = mapped_column(UtcDateTime(), nullable=True) @property def url(self) -> str: @@ -62,7 +80,7 @@ def url_thumbnail(self) -> str: return f'{current_app.config["PROJECT_URL"]}/static/images/dynamic/{self.id}.thumb.{self.type}' @property - def path_thumbnail(self): + def path_thumbnail(self) -> Path: return Path(current_app.config['DYNAMIC_IMAGE_DIR'], f'{self.id}.thumb.{self.type}') def to_dict(self, *args, strict: bool = False, ignore: list[str] | None = None, **kwargs) -> dict: @@ -78,12 +96,3 @@ def to_dict(self, *args, strict: bool = False, ignore: list[str] | None = None, result['last_download'] = self.last_download return result - - -# TODO: use this for UPDATE checks -# @event.listens_for(Image, 'before_insert') -# @event.listens_for(Image, 'before_update') -# def set_geometry(mapper, connection, image): -# state = db.inspect(image) -# for attr in state.attrs: -# print(state.get_history(attr.key, True)) diff --git a/webapp/models/location.py b/webapp/models/location.py index 79a49f99..7a7709d5 100644 --- a/webapp/models/location.py +++ b/webapp/models/location.py @@ -22,15 +22,30 @@ from enum import Enum from typing import TYPE_CHECKING, Optional -from sqlalchemy import Index, event, func +from sqlalchemy import ( + BigInteger, + Boolean, + Float, + ForeignKey, + Index, + Numeric, + String, + Text, + event, + func, +) +from sqlalchemy import ( + Enum as SqlalchemyEnum, +) from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy_utc import UtcDateTime from webapp.common.json import DefaultJSONEncoder -from webapp.common.sqlalchemy import Mapped from webapp.extensions import db from .base import BaseModel, Point +from .location_image import LocationImageAssociation if TYPE_CHECKING: from .business import Business @@ -98,87 +113,88 @@ class TokenType(Enum): RFID = 'RFID' -location_image = db.Table( - 'location_image', - db.Column('location_id', db.BigInteger, db.ForeignKey('location.id', use_alter=True), nullable=False), - db.Column('image_id', db.BigInteger, db.ForeignKey('image.id', use_alter=True), nullable=False), -) - - -class Location(db.Model, BaseModel): +class Location(BaseModel): __tablename__ = 'location' __table_args__ = (Index('uid', 'source'),) - evses: Mapped[list['Evse']] = db.relationship( + evses: Mapped[list['Evse']] = relationship( 'Evse', back_populates='location', cascade='all, delete, delete-orphan', ) - images: Mapped[list['Image']] = db.relationship('Image', secondary=location_image) + images: Mapped[list['Image']] = relationship( + 'Image', + secondary=LocationImageAssociation.__table__, + back_populates='locations', + ) - exceptional_openings: Mapped[list['ExceptionalOpeningPeriod']] = db.relationship( + exceptional_openings: Mapped[list['ExceptionalOpeningPeriod']] = relationship( 'ExceptionalOpeningPeriod', back_populates='location', cascade='all, delete, delete-orphan', ) - exceptional_closings: Mapped[list['ExceptionalClosingPeriod']] = db.relationship( + exceptional_closings: Mapped[list['ExceptionalClosingPeriod']] = relationship( 'ExceptionalClosingPeriod', back_populates='location', cascade='all, delete, delete-orphan', ) - regular_hours: Mapped[list['RegularHours']] = db.relationship( + regular_hours: Mapped[list['RegularHours']] = relationship( 'RegularHours', back_populates='location', cascade='all, delete, delete-orphan', ) - operator_id: Mapped[int | None] = db.Column( - db.BigInteger, - db.ForeignKey('business.id', use_alter=True), + operator_id: Mapped[int | None] = mapped_column( + BigInteger, + ForeignKey('business.id', use_alter=True), + nullable=True, + ) + suboperator_id: Mapped[int | None] = mapped_column( + BigInteger, + ForeignKey('business.id', use_alter=True), nullable=True, ) - suboperator_id: Mapped[int | None] = db.Column( - db.BigInteger, - db.ForeignKey('business.id', use_alter=True), + owner_id: Mapped[int | None] = mapped_column( + BigInteger, + ForeignKey('business.id', use_alter=True), nullable=True, ) - owner_id: Mapped[int | None] = db.Column(db.BigInteger, db.ForeignKey('business.id', use_alter=True), nullable=True) - operator: Mapped[Optional['Business']] = db.relationship('Business', foreign_keys=[operator_id]) - suboperator: Mapped[Optional['Business']] = db.relationship('Business', foreign_keys=[suboperator_id]) - owner: Mapped[Optional['Business']] = db.relationship('Business', foreign_keys=[owner_id]) + operator: Mapped[Optional['Business']] = relationship('Business', foreign_keys=[operator_id]) + suboperator: Mapped[Optional['Business']] = relationship('Business', foreign_keys=[suboperator_id]) + owner: Mapped[Optional['Business']] = relationship('Business', foreign_keys=[owner_id]) - uid: Mapped[str] = db.Column(db.String(255), index=True, nullable=False) # OCHP: locationId OCPI: id - source: Mapped[str] = db.Column(db.String(64), index=True, nullable=False) + uid: Mapped[str] = mapped_column(String(255), index=True, nullable=False) # OCHP: locationId OCPI: id + source: Mapped[str] = mapped_column(String(64), index=True, nullable=False) - dynamic_location_id: Mapped[int] = db.Column(db.BigInteger) # TODO: relation? - dynamic_location_probability: Mapped[int] = db.Column(db.Float) + dynamic_location_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + dynamic_location_probability: Mapped[float | None] = mapped_column(Float, nullable=True) - name: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCHP: locationName, OCPI: name + name: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCHP: locationName, OCPI: name # OCHP: chargePointAddress.address, OCPI: address - address: Mapped[str | None] = db.Column(db.String(255), nullable=True) + address: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCHP: chargePointAddress.zipCode, OCPI: postal_code - postal_code: Mapped[str | None] = db.Column(db.String(255), nullable=True) - city: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCHP: chargePointAddress.city, OCPI: city - state: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCPI: state + postal_code: Mapped[str | None] = mapped_column(String(255), nullable=True) + city: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCHP: chargePointAddress.city, OCPI: city + state: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCPI: state # OCHP: chargePointAddress.country, OCPI: country - country: Mapped[str | None] = db.Column(db.String(3), nullable=True) + country: Mapped[str | None] = mapped_column(String(3), nullable=True) # OCHP: chargePointLocation.lat, OCPI: coordinates.latitude - lat: Mapped[Decimal | None] = db.Column(db.Numeric(9, 7), nullable=True) + lat: Mapped[Decimal | None] = mapped_column(Numeric(9, 7), nullable=True) # OCHP: chargePointLocation.lon, OCPI: coordinates.longitude - lon: Mapped[Decimal | None] = db.Column(db.Numeric(10, 7), nullable=True) + lon: Mapped[Decimal | None] = mapped_column(Numeric(10, 7), nullable=True) - _directions: Mapped[str | None] = db.Column('directions', db.Text, nullable=True) # OCPI: directions - parking_type: Mapped[ParkingType | None] = db.Column(db.Enum(ParkingType), nullable=True) - time_zone: Mapped[str | None] = db.Column(db.String(32), nullable=True) # OCHP: timeZone, OCPI: time_zone + _directions: Mapped[str | None] = mapped_column('directions', Text, nullable=True) # OCPI: directions + parking_type: Mapped[ParkingType | None] = mapped_column(SqlalchemyEnum(ParkingType), nullable=True) + time_zone: Mapped[str | None] = mapped_column(String(32), nullable=True) # OCHP: timeZone, OCPI: time_zone - last_updated: Mapped[datetime | None] = db.Column(UtcDateTime(), nullable=True) # OCHP: timestamp + last_updated: Mapped[datetime | None] = mapped_column(UtcDateTime(), nullable=True) # OCHP: timestamp - terms_and_conditions: Mapped[str | None] = db.Column(db.String(255), nullable=True) # OCPI: terms_and_conditions + terms_and_conditions: Mapped[str | None] = mapped_column(String(255), nullable=True) # OCPI: terms_and_conditions # OCHP: openingTimes.twentyfourseven OCPI: opening_times.twentyfourseven - twentyfourseven: Mapped[bool | None] = db.Column(db.Boolean, nullable=True) + twentyfourseven: Mapped[bool | None] = mapped_column(Boolean, nullable=True) - geometry: Mapped[Point] = db.Column(Point(), nullable=False) + geometry: Mapped[Point] = mapped_column(Point(), nullable=False) @hybrid_property def directions(self) -> list[dict[str, str]] | None: diff --git a/webapp/models/location_image.py b/webapp/models/location_image.py new file mode 100644 index 00000000..64598ee9 --- /dev/null +++ b/webapp/models/location_image.py @@ -0,0 +1,39 @@ +""" +Open ChargePoint DataBase OCPDB +Copyright (C) 2025 binary butterfly GmbH + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . +""" + +from sqlalchemy import BigInteger, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column + +from webapp.extensions import db + + +class LocationImageAssociation(db.Model): + __tablename__ = 'location_image' + + location_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey('location.id'), + primary_key=True, + nullable=False, + ) + image_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey('image.id'), + primary_key=True, + nullable=False, + ) diff --git a/webapp/models/option.py b/webapp/models/option.py index d5e984ea..06a041d8 100644 --- a/webapp/models/option.py +++ b/webapp/models/option.py @@ -16,18 +16,19 @@ along with this program. If not, see . """ -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db +from sqlalchemy import Enum as SqlalchemyEnum +from sqlalchemy import String, Text +from sqlalchemy.orm import Mapped, mapped_column from .base import BaseModel -class Option(db.Model, BaseModel): +class Option(BaseModel): __tablename__ = 'option' - key: Mapped[str | None] = db.Column(db.String(128), index=True, nullable=True) - type: Mapped[str | None] = db.Column( - db.Enum('string', 'date', 'datetime', 'integer', 'decimal', 'dict', 'list', name='OptionType'), + key: Mapped[str | None] = mapped_column(String(128), index=True, nullable=True) + type: Mapped[str | None] = mapped_column( + SqlalchemyEnum('string', 'date', 'datetime', 'integer', 'decimal', 'dict', 'list', name='OptionType'), nullable=True, ) - value: Mapped[str | None] = db.Column(db.Text, nullable=True) + value: Mapped[str | None] = mapped_column(Text, nullable=True) diff --git a/webapp/models/regular_hours.py b/webapp/models/regular_hours.py index 5002c4ef..d7f002e4 100644 --- a/webapp/models/regular_hours.py +++ b/webapp/models/regular_hours.py @@ -18,8 +18,8 @@ from typing import TYPE_CHECKING -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db +from sqlalchemy import BigInteger, ForeignKey, Integer, SmallInteger +from sqlalchemy.orm import Mapped, mapped_column, relationship from .base import BaseModel @@ -27,15 +27,15 @@ from .location import Location -class RegularHours(db.Model, BaseModel): +class RegularHours(BaseModel): __tablename__ = 'regular_hours' - location: Mapped['Location'] = db.relationship('Location', back_populates='regular_hours') - location_id = db.Column(db.BigInteger, db.ForeignKey('location.id', use_alter=True), nullable=False) + location: Mapped['Location'] = relationship('Location', back_populates='regular_hours') + location_id = mapped_column(BigInteger, ForeignKey('location.id', use_alter=True), nullable=False) - weekday: Mapped[int] = db.Column(db.SmallInteger, nullable=False) - period_begin: Mapped[int] = db.Column(db.Integer, nullable=False) - period_end: Mapped[int] = db.Column(db.Integer, nullable=False) + weekday: Mapped[int] = mapped_column(SmallInteger, nullable=False) + period_begin: Mapped[int] = mapped_column(Integer, nullable=False) + period_end: Mapped[int] = mapped_column(Integer, nullable=False) def to_dict(self, *args, ignore: list[str] | None = None, **kwargs) -> dict: ignore = ignore or [] diff --git a/webapp/models/related_resource.py b/webapp/models/related_resource.py index 6a40a0ec..3bc81cd2 100644 --- a/webapp/models/related_resource.py +++ b/webapp/models/related_resource.py @@ -17,12 +17,11 @@ """ from enum import Enum -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING +from sqlalchemy import BigInteger, ForeignKey, Integer, String from sqlalchemy.ext.hybrid import hybrid_property - -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db +from sqlalchemy.orm import Mapped, mapped_column, relationship from .base import BaseModel @@ -40,17 +39,17 @@ class RelatedResourceType(Enum): OPENING_TIMES = 'OPENING_TIMES' -class RelatedResource(db.Model, BaseModel): +class RelatedResource(BaseModel): __tablename__ = 'related_resource' - evse: Mapped['Evse'] = db.relationship('Evse', back_populates='related_resources') - evse_id: Mapped[int] = db.Column(db.BigInteger, db.ForeignKey('evse.id', use_alter=True), nullable=False) + evse: Mapped['Evse'] = relationship('Evse', back_populates='related_resources') + evse_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('evse.id', use_alter=True), nullable=False) - url: Mapped[str | None] = db.Column(db.String(255), nullable=True) - _types: Mapped[int | None] = db.Column('types', db.Integer, nullable=True) + url: Mapped[str | None] = mapped_column(String(255), nullable=True) + _types: Mapped[int | None] = mapped_column('types', Integer, nullable=True) @hybrid_property - def types(self) -> List[RelatedResourceType]: + def types(self) -> list[RelatedResourceType]: if not self._types: return [] return sorted( @@ -58,7 +57,7 @@ def types(self) -> List[RelatedResourceType]: ) @types.setter - def types(self, types: List[RelatedResourceType]) -> None: + def types(self, types: list[RelatedResourceType]) -> None: self._types = 0 for _type in types: self._types = self._types | _type.value diff --git a/webapp/models/source.py b/webapp/models/source.py index 98c2a9a2..9907409b 100644 --- a/webapp/models/source.py +++ b/webapp/models/source.py @@ -19,11 +19,11 @@ from datetime import datetime from enum import Enum +from sqlalchemy import Enum as SqlalchemyEnum +from sqlalchemy import Integer, String, Text +from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy_utc import UtcDateTime -from webapp.common.sqlalchemy import Mapped -from webapp.extensions import db - from .base import BaseModel @@ -34,33 +34,33 @@ class SourceStatus(Enum): PROVISIONED = 'PROVISIONED' -class Source(db.Model, BaseModel): +class Source(BaseModel): __tablename__ = 'source' - uid: Mapped[str] = db.Column(db.String(256), nullable=False, index=True, unique=True) - name: Mapped[str | None] = db.Column(db.String(256), nullable=True) - public_url: Mapped[str | None] = db.Column(db.String(4096), nullable=True) + uid: Mapped[str] = mapped_column(String(256), nullable=False, index=True, unique=True) + name: Mapped[str | None] = mapped_column(String(256), nullable=True) + public_url: Mapped[str | None] = mapped_column(String(4096), nullable=True) - static_data_updated_at: Mapped[datetime | None] = db.Column(UtcDateTime(), nullable=True) - realtime_data_updated_at: Mapped[datetime | None] = db.Column(UtcDateTime(), nullable=True) + static_data_updated_at: Mapped[datetime | None] = mapped_column(UtcDateTime(), nullable=True) + realtime_data_updated_at: Mapped[datetime | None] = mapped_column(UtcDateTime(), nullable=True) - attribution_license: Mapped[str | None] = db.Column(db.Text(), nullable=True) - attribution_contributor: Mapped[str | None] = db.Column(db.String(256), nullable=True) - attribution_url: Mapped[str | None] = db.Column(db.String(256), nullable=True) + attribution_license: Mapped[str | None] = mapped_column(Text, nullable=True) + attribution_contributor: Mapped[str | None] = mapped_column(String(256), nullable=True) + attribution_url: Mapped[str | None] = mapped_column(String(256), nullable=True) - static_status: Mapped[SourceStatus] = db.Column( - db.Enum(SourceStatus), + static_status: Mapped[SourceStatus] = mapped_column( + SqlalchemyEnum(SourceStatus), nullable=False, default=SourceStatus.PROVISIONED, ) - realtime_status: Mapped[SourceStatus] = db.Column( - db.Enum(SourceStatus), + realtime_status: Mapped[SourceStatus] = mapped_column( + SqlalchemyEnum(SourceStatus), nullable=False, default=SourceStatus.PROVISIONED, ) - static_error_count: Mapped[int] = db.Column(db.Integer(), nullable=False, default=0) - realtime_error_count: Mapped[int] = db.Column(db.Integer(), nullable=False, default=0) + static_error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + realtime_error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) @property def combined_status(self) -> SourceStatus: diff --git a/webapp/repositories/__init__.py b/webapp/repositories/__init__.py index 1d9391b2..36c1efbf 100644 --- a/webapp/repositories/__init__.py +++ b/webapp/repositories/__init__.py @@ -16,10 +16,10 @@ along with this program. If not, see . """ -from .base_repository import ObjectNotFoundException from .business_repository import BusinessRepository from .connector_repository import ConnectorRepository from .evse_repository import EvseRepository +from .exceptions import ObjectNotFoundException from .image_repository import ImageRepository from .location_repository import LocationRepository from .option_repository import OptionRepository diff --git a/webapp/repositories/base_repository.py b/webapp/repositories/base_repository.py index 5ed292be..22b1c220 100644 --- a/webapp/repositories/base_repository.py +++ b/webapp/repositories/base_repository.py @@ -19,30 +19,15 @@ from abc import ABC, abstractmethod from typing import Any, Generic, Optional, Type, TypeVar -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, scoped_session from validataclass_search_queries.repositories import SearchQueryRepositoryMixin -from webapp.common.error_handling.exceptions import AppException -from webapp.extensions import db from webapp.models.base import BaseModel +from webapp.repositories.exceptions import ObjectNotFoundException T_Model = TypeVar('T_Model', bound=BaseModel) -class InconsistentDataException(AppException): - code = 'inconsistent_data' - - -class ObjectNotFoundException(AppException): - """ - The requested object was not found or is out of scope. - This exception may be extended (e.g. UserNotFoundException) for specific object types if needed. - """ - - code = 'not_found' - http_status = 404 - - class BaseRepository(SearchQueryRepositoryMixin[T_Model], Generic[T_Model], ABC): @property @abstractmethod @@ -51,12 +36,37 @@ def model_cls(self) -> Type[T_Model]: session: Session - def __init__(self, session: Optional[Session] = None) -> None: - self.session = db.session if session is None else session + def __init__(self, session: scoped_session) -> None: + self.session = session def exists(self, obj, field, value): return self.session.query(obj).filter(**{field: value}).count() > 0 + def fetch_resource_by_id( + self, + resource_id: int, + *, + load_options: list | None = None, + resource_name: str | None = None, + ) -> T_Model: + """ + Fetch a resource by its ID. + Raises ObjectNotFoundException if the resource does not exist or is out of scope. + """ + load_options = load_options or [] + + resource = ( + self.session.query(self.model_cls) + .options(*load_options) + .filter(self.model_cls.id == resource_id) + .one_or_none() + ) + + return self._or_raise( + resource, + f'{resource_name or self.model_cls.__name__} with ID {resource_id} was not found.', + ) + @staticmethod def _or_raise( resource: Optional[Any], diff --git a/webapp/repositories/business_repository.py b/webapp/repositories/business_repository.py index bba7d09e..e66e07ae 100644 --- a/webapp/repositories/business_repository.py +++ b/webapp/repositories/business_repository.py @@ -23,19 +23,14 @@ from webapp.models import Business -from .base_repository import BaseRepository, ObjectNotFoundException +from .base_repository import BaseRepository class BusinessRepository(BaseRepository[Business]): model_cls = Business def fetch_by_id(self, business_id: int) -> Business: - result = self.session.query(Business).get(business_id) - - if result is None: - raise ObjectNotFoundException(f'business with id {business_id} not found') - - return result + return self.fetch_resource_by_id(business_id) def fetch_businesses(self, search_query: Optional[BaseSearchQuery] = None) -> PaginatedResult[Business]: query = self.session.query(Business) @@ -44,7 +39,4 @@ def fetch_businesses(self, search_query: Optional[BaseSearchQuery] = None) -> Pa def fetch_business_by_name(self, name: str) -> Business: result = self.session.query(Business).filter(Business.name == name).first() - if result is None: - raise ObjectNotFoundException(f'business with name {name} not found') - - return result + return self._or_raise(result, f'business with name {name} not found') diff --git a/webapp/repositories/connector_repository.py b/webapp/repositories/connector_repository.py index 239036b9..7c291b48 100644 --- a/webapp/repositories/connector_repository.py +++ b/webapp/repositories/connector_repository.py @@ -16,25 +16,18 @@ along with this program. If not, see . """ -from typing import List - from webapp.models import Connector, Evse, Location -from .base_repository import BaseRepository, ObjectNotFoundException +from .base_repository import BaseRepository class ConnectorRepository(BaseRepository[Connector]): model_cls = Connector def fetch_by_id(self, connector_id: int) -> Connector: - result = self.session.query(Connector).get(connector_id) - - if result is None: - raise ObjectNotFoundException(f'connector with id {connector_id} not found') - - return result + return self.fetch_resource_by_id(connector_id) - def fetch_connectors_by_ids(self, connector_ids: List[int]) -> List[Connector]: + def fetch_connectors_by_ids(self, connector_ids: list[int]) -> list[Connector]: return self.session.query(Connector).filter(Connector.id.in_(connector_ids)).all() def fetch_by_uid(self, source: str, connector_uid: str) -> Connector: @@ -47,13 +40,4 @@ def fetch_by_uid(self, source: str, connector_uid: str) -> Connector: .first() ) - if result is None: - raise ObjectNotFoundException(message=f'connector with uid {connector_uid} and source {source} not found') - - return result - - def delete_connector_by_id(self, connector_ids: List[int]): - self.session.query(Connector).filter(Connector.id.in_(connector_ids)).delete(synchronize_session=False) - - def delete_connector_by_ids(self, connector_id: int): - self.session.query(Connector).filter(Connector.id == connector_id).delete(synchronize_session=False) + return self._or_raise(result, f'connector with uid {connector_uid} and source {source} not found') diff --git a/webapp/repositories/evse_repository.py b/webapp/repositories/evse_repository.py index a7342aad..60ab321b 100644 --- a/webapp/repositories/evse_repository.py +++ b/webapp/repositories/evse_repository.py @@ -17,12 +17,12 @@ """ from dataclasses import dataclass -from typing import List, Tuple from webapp.models import Evse, Location from webapp.models.evse import EvseStatus -from .base_repository import BaseRepository, InconsistentDataException, ObjectNotFoundException +from .base_repository import BaseRepository +from .exceptions import InconsistentDataException, ObjectNotFoundException @dataclass @@ -61,38 +61,30 @@ def fetch_by_uid(self, source: str, uid: str) -> Evse: return items[0] - def fetch_evse_by_location_id(self, location_id: int) -> List[Evse]: - return self.session.query(Evse).filter(Evse.location_id == location_id) + def fetch_evse_by_location_id(self, location_id: int) -> list[Evse]: + return self.session.query(Evse).filter(Evse.location_id == location_id).all() - def fetch_evse_uids(self) -> List[str]: + def fetch_evse_uids(self) -> list[str]: items = self.session.query(Evse.uid).all() return [item.uid for item in items] - def fetch_extended_evse_uids(self) -> List[Tuple[str, int]]: + def fetch_extended_evse_uids(self) -> list[tuple[str, int]]: items = self.session.query(Evse.uid, Evse.location_id).all() return [(item.uid, item.location_id) for item in items] def save_evse(self, evse: Evse, *, commit: bool = True): - self.session.add(evse) - if commit: - self.session.commit() + self._save_resources(evse, commit=commit) - def delete_evse_by_ids(self, evse_ids: List[int]): - self.session.query(Evse).filter(Evse.id.in_(evse_ids)).delete(synchronize_session=False) - - def delete_evse_by_id(self, evse_id: int): - self.session.query(Evse).filter(id=evse_id).delete(synchronize_session=False) - - def fetch_evse_status_summary(self) -> List[EvseStatusSummary]: + def fetch_evse_status_summary(self) -> list[EvseStatusSummary]: items = ( self.session.query(Evse.uid.label('evse'), Evse.status, Location.uid.label('location'), Location.source) .filter(Evse.status != EvseStatus.STATIC) .join(Evse.location) .all() ) - result: List[EvseStatusSummary] = [] + result: list[EvseStatusSummary] = [] for item in items: result.append(EvseStatusSummary(**dict(item))) diff --git a/tests/integration/test_app_start.py b/webapp/repositories/exceptions.py similarity index 59% rename from tests/integration/test_app_start.py rename to webapp/repositories/exceptions.py index a19aeb18..ee649463 100644 --- a/tests/integration/test_app_start.py +++ b/webapp/repositories/exceptions.py @@ -1,6 +1,6 @@ """ Open ChargePoint DataBase OCPDB -Copyright (C) 2021 binary butterfly GmbH +Copyright (C) 2025 binary butterfly GmbH This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by @@ -16,10 +16,18 @@ along with this program. If not, see . """ -from webapp.common.flask_app import App +from webapp.common.error_handling.exceptions import AppException -def test_start_app(app: App): - with app.test_client() as client: - response = client.get('/api/public/v1/businesses/1') - assert response +class InconsistentDataException(AppException): + code = 'inconsistent_data' + + +class ObjectNotFoundException(AppException): + """ + The requested object was not found or is out of scope. + This exception may be extended (e.g. UserNotFoundException) for specific object types if needed. + """ + + code = 'not_found' + http_status = 404 diff --git a/webapp/repositories/image_repository.py b/webapp/repositories/image_repository.py index 230f21b3..bc81dc44 100644 --- a/webapp/repositories/image_repository.py +++ b/webapp/repositories/image_repository.py @@ -20,19 +20,14 @@ from webapp.models import Image -from .base_repository import BaseRepository, ObjectNotFoundException +from .base_repository import BaseRepository class ImageRepository(BaseRepository[Image]): model_cls = Image def fetch_image_by_id(self, image_id: int) -> Image: - result = self.session.query(Image).get(image_id) - - if result is None: - raise ObjectNotFoundException(f'image with id {image_id} not found') - - return result + return self.fetch_resource_by_id(image_id) def fetch_images(self) -> list[Image]: return self.session.query(Image).all() @@ -40,13 +35,10 @@ def fetch_images(self) -> list[Image]: def fetch_image_by_url(self, image_url: str) -> Image: image = self.session.query(Image).filter(Image.external_url == image_url).first() - if image is None: - raise ObjectNotFoundException(f'image with url {image_url} not found') - - return image + return self._or_raise(image, f'image with url {image_url} not found') def fetch_outdated_images(self) -> list[Image]: - return self.session.query(Image).filter( + query = self.session.query(Image).filter( or_( Image.last_download.is_(None), Image.last_download < Image.modified, @@ -54,5 +46,7 @@ def fetch_outdated_images(self) -> list[Image]: Image.external_url.isnot(None), ) + return query.all() + def save_image(self, image: Image, *, commit: bool = True) -> None: self._save_resources(image, commit=commit) diff --git a/webapp/repositories/location_repository.py b/webapp/repositories/location_repository.py index d39ce5ff..a18ba0ee 100644 --- a/webapp/repositories/location_repository.py +++ b/webapp/repositories/location_repository.py @@ -16,10 +16,8 @@ along with this program. If not, see . """ -from typing import List, Optional - from mercantile import LngLatBbox -from sqlalchemy import func +from sqlalchemy import func, text from sqlalchemy.orm import joinedload, selectinload from validataclass_search_queries.pagination import PaginatedResult from validataclass_search_queries.search_queries import BaseSearchQuery @@ -27,24 +25,25 @@ from webapp.common.sqlalchemy import Query from webapp.models import Business, Evse, Location -from .base_repository import BaseRepository, ObjectNotFoundException +from .base_repository import BaseRepository +from .exceptions import ObjectNotFoundException class LocationRepository(BaseRepository[Location]): model_cls = Location - def fetch_locations_by_source(self, source: str, include_children: bool = True) -> List[Location]: - locations = self.session.query(Location) + def fetch_locations_by_source(self, source: str, include_children: bool = True) -> list[Location]: + query = self.session.query(Location) if include_children: - locations = locations.options([ + query = query.options( selectinload(Location.evses).selectinload(Evse.connectors), selectinload(Location.operator), - ]) + ) - return locations.filter(Location.source == source).all() + return query.filter(Location.source == source).all() - def fetch_location_ids_by_source(self, source: str) -> List[int]: + def fetch_location_ids_by_source(self, source: str) -> list[int]: items = self.session.query(Location.id).filter(Location.source == source).all() return [item.id for item in items] @@ -53,10 +52,10 @@ def fetch_location_by_id(self, location_id: int, *, include_children: bool = Fal location = self.session.query(Location) if include_children: - location = location.options([ + location = location.options( selectinload(Location.evses).selectinload(Evse.connectors), selectinload(Location.operator), - ]) + ) location = location.get(location_id) @@ -66,33 +65,29 @@ def fetch_location_by_id(self, location_id: int, *, include_children: bool = Fal return location def fetch_location_by_uid(self, source: str, location_uid: str, *, include_children: bool = False) -> Location: - location = self.session.query(Location) + query = self.session.query(Location) + if include_children: - location = location.options([ + query = query.options( selectinload(Location.evses).selectinload(Evse.connectors), selectinload(Location.operator).selectinload(Business.logo), selectinload(Location.suboperator).selectinload(Business.logo), selectinload(Location.owner).selectinload(Business.logo), selectinload(Location.images), selectinload(Location.evses).selectinload(Evse.images), - ]) - - location = location.filter(Location.uid == location_uid).first() + ) - if location is None: - raise ObjectNotFoundException(message=f'location with uid {location_uid} and source {source} not found') + location = query.filter(Location.uid == location_uid).first() - return location + return self._or_raise(location, f'location with uid {location_uid} and source {source} not found') def save_location(self, location: Location, *, commit: bool = True): - self.session.add(location) - if commit: - self.session.commit() + self._save_resources(location, commit=commit) def fetch_locations_summary_by_bounds( self, bbox: LngLatBbox, - static: Optional[bool] = None, + static: bool | None = None, filter_duplicates: bool = True, ) -> list: additional_where = '' @@ -121,9 +116,9 @@ def fetch_locations_summary_by_bounds( query += f'{additional_where} GROUP BY location.id' - return list(self.session.execute(query)) + return list(self.session.execute(text(query))) - def fetch_locations_by_bounds(self, bbox: LngLatBbox) -> List[Location]: + def fetch_locations_by_bounds(self, bbox: LngLatBbox) -> list[Location]: locations = self.session.query(Location) if self.session.connection().dialect.name == 'postgresql': @@ -158,7 +153,7 @@ def delete_location(self, location: Location, *, commit: bool = True): if commit: self.session.commit() - def fetch_locations(self, search_query: Optional[BaseSearchQuery] = None) -> PaginatedResult[Location]: + def fetch_locations(self, search_query: BaseSearchQuery | None = None) -> PaginatedResult[Location]: options = [ selectinload(Location.images), selectinload(Location.evses).selectinload(Evse.connectors), @@ -175,7 +170,7 @@ def fetch_locations(self, search_query: Optional[BaseSearchQuery] = None) -> Pag query = self.session.query(Location).options(*options) return self._search_and_paginate(query, search_query) - def _filter_by_search_query(self, query: Query, search_query: Optional[BaseSearchQuery]) -> Query: + def _filter_by_search_query(self, query: Query, search_query: BaseSearchQuery | None) -> Query: if search_query is None: return query diff --git a/webapp/repositories/source_repository.py b/webapp/repositories/source_repository.py index 3dd21f7f..bf04f620 100644 --- a/webapp/repositories/source_repository.py +++ b/webapp/repositories/source_repository.py @@ -16,8 +16,6 @@ along with this program. If not, see . """ -from typing import List - from webapp.models import Source from .base_repository import BaseRepository @@ -26,7 +24,7 @@ class SourceRepository(BaseRepository[Source]): model_cls = Source - def fetch_sources(self) -> List[Source]: + def fetch_sources(self) -> list[Source]: return self.session.query(Source).all() def fetch_source_by_uid(self, source_uid: str) -> Source: