diff --git a/app/api/v2/managers/base_api_manager.py b/app/api/v2/managers/base_api_manager.py index f72368790..83da39b04 100644 --- a/app/api/v2/managers/base_api_manager.py +++ b/app/api/v2/managers/base_api_manager.py @@ -1,5 +1,6 @@ import logging import os +import re import uuid import yaml @@ -7,6 +8,7 @@ from typing import Any, List from base64 import b64encode, b64decode +from app.api.v2.errors import DataValidationError from app.utility.base_world import BaseWorld @@ -64,6 +66,7 @@ def create_object_from_schema(self, schema: SchemaMeta, data: dict, access: Base async def create_on_disk_object(self, data: dict, access: dict, ram_key: str, id_property: str, obj_class: type): obj_id = data.get(id_property) or str(uuid.uuid4()) + obj_id = self._sanitize_id(obj_id) data[id_property] = obj_id file_path = await self._get_new_object_file_path(data[id_property], ram_key) @@ -121,18 +124,34 @@ async def remove_object_from_memory_by_id(self, identifier: str, ram_key: str, i await self._data_svc.remove(ram_key, {id_property: identifier}) async def remove_object_from_disk_by_id(self, identifier: str, ram_key: str): + identifier = self._sanitize_id(identifier) file_path = await self._get_existing_object_file_path(identifier, ram_key) if os.path.exists(file_path): os.remove(file_path) + @staticmethod + def _sanitize_id(obj_id) -> str: + '''Removes any non-alphanumeric characters and non-hyphen/underscore.''' + if not isinstance(obj_id, str): + raise DataValidationError(message=f'Invalid id type: expected str, got {type(obj_id).__name__}', name='id', value=obj_id) + original_id = obj_id + obj_id = re.sub(r'[^a-zA-Z0-9_-]', '', obj_id) + if not obj_id: + raise DataValidationError(message=f"Invalid id: {obj_id!r}", name='id', value=obj_id) + if original_id != obj_id: + logging.getLogger(DEFAULT_LOGGER_NAME).warning(f"Sanitized ID: {obj_id}") + return obj_id + @staticmethod async def _get_new_object_file_path(identifier: str, ram_key: str) -> str: """Create file path for new object""" + identifier = BaseApiManager._sanitize_id(identifier) return os.path.join('data', ram_key, f'{identifier}.yml') async def _get_existing_object_file_path(self, identifier: str, ram_key: str) -> str: """Find file path for existing object (by id)""" + identifier = self._sanitize_id(identifier) _, file_path = await self._file_svc.find_file_path(f'{identifier}.yml', location=ram_key) if not file_path: file_path = await self._get_new_object_file_path(identifier, ram_key) diff --git a/tests/api/v2/managers/test_base_api_manager.py b/tests/api/v2/managers/test_base_api_manager.py index 7a6495cf1..3cf4560e5 100644 --- a/tests/api/v2/managers/test_base_api_manager.py +++ b/tests/api/v2/managers/test_base_api_manager.py @@ -1,5 +1,7 @@ +import pytest import marshmallow as ma +from app.api.v2.errors import DataValidationError from app.api.v2.managers.base_api_manager import BaseApiManager from app.objects.interfaces.i_object import FirstClassObjectInterface from app.utility.base_object import BaseObject @@ -248,3 +250,28 @@ def test_replace_object(agent): assert len(stub_data_svc.ram['tests']) == 1 assert not stub_data_svc.ram['tests'][0].value + + +def test_sanitize_id(): + valid = '766be199-7316-4b26-b3db-e272aaf7e0d4' + assert valid == BaseApiManager._sanitize_id(valid) + assert valid.upper() == BaseApiManager._sanitize_id(valid.upper()) + assert valid == BaseApiManager._sanitize_id('../.././&%$!"#766be19:9-73[16-]4b}26-b{3d!b-e272..\\//aaf/*7e0d4') + assert 'testid123TEST' == BaseApiManager._sanitize_id('testid123TEST') + with pytest.raises(DataValidationError): + BaseApiManager._sanitize_id('../../.') + with pytest.raises(DataValidationError): + BaseApiManager._sanitize_id('') + # Non-string IDs should raise a DataValidationError + with pytest.raises(DataValidationError): + BaseApiManager._sanitize_id(12345) + + +def test_sanitize_id_logs_warning_when_changed(caplog): + # Capture warnings when an ID is mutated by sanitization + caplog.set_level('WARNING') + original = 'abc/def?ghi' + sanitized = BaseApiManager._sanitize_id(original) + assert sanitized == 'abcdefghi' + # Ensure a warning was emitted that includes the sanitized ID + assert any('Sanitized ID' in rec.getMessage() and sanitized in rec.getMessage() for rec in caplog.records)