Skip to content
Merged
19 changes: 19 additions & 0 deletions app/api/v2/managers/base_api_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
import os
import re
import uuid
import yaml

from marshmallow.schema import SchemaMeta
from typing import Any, List
from base64 import b64encode, b64decode

from app.api.v2.errors import DataValidationError
from app.utility.base_world import BaseWorld


Expand Down Expand Up @@ -64,6 +66,7 @@ def create_object_from_schema(self, schema: SchemaMeta, data: dict, access: Base

async def create_on_disk_object(self, data: dict, access: dict, ram_key: str, id_property: str, obj_class: type):
obj_id = data.get(id_property) or str(uuid.uuid4())
obj_id = self._sanitize_id(obj_id)
data[id_property] = obj_id

file_path = await self._get_new_object_file_path(data[id_property], ram_key)
Expand Down Expand Up @@ -121,18 +124,34 @@ async def remove_object_from_memory_by_id(self, identifier: str, ram_key: str, i
await self._data_svc.remove(ram_key, {id_property: identifier})

async def remove_object_from_disk_by_id(self, identifier: str, ram_key: str):
identifier = self._sanitize_id(identifier)
file_path = await self._get_existing_object_file_path(identifier, ram_key)

if os.path.exists(file_path):
os.remove(file_path)

@staticmethod
def _sanitize_id(obj_id) -> str:
'''Removes any non-alphanumeric characters and non-hyphen/underscore.'''
if not isinstance(obj_id, str):
raise DataValidationError(message=f'Invalid id type: expected str, got {type(obj_id).__name__}', name='id', value=obj_id)
original_id = obj_id
obj_id = re.sub(r'[^a-zA-Z0-9_-]', '', obj_id)
if not obj_id:
raise DataValidationError(message=f"Invalid id: {obj_id!r}", name='id', value=obj_id)
if original_id != obj_id:
logging.getLogger(DEFAULT_LOGGER_NAME).warning(f"Sanitized ID: {obj_id}")
return obj_id
Comment thread
uruwhy marked this conversation as resolved.

@staticmethod
async def _get_new_object_file_path(identifier: str, ram_key: str) -> str:
"""Create file path for new object"""
Comment thread
uruwhy marked this conversation as resolved.
identifier = BaseApiManager._sanitize_id(identifier)
return os.path.join('data', ram_key, f'{identifier}.yml')

async def _get_existing_object_file_path(self, identifier: str, ram_key: str) -> str:
"""Find file path for existing object (by id)"""
identifier = self._sanitize_id(identifier)
_, file_path = await self._file_svc.find_file_path(f'{identifier}.yml', location=ram_key)
if not file_path:
file_path = await self._get_new_object_file_path(identifier, ram_key)
Expand Down
27 changes: 27 additions & 0 deletions tests/api/v2/managers/test_base_api_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
import marshmallow as ma

from app.api.v2.errors import DataValidationError
from app.api.v2.managers.base_api_manager import BaseApiManager
from app.objects.interfaces.i_object import FirstClassObjectInterface
from app.utility.base_object import BaseObject
Expand Down Expand Up @@ -248,3 +250,28 @@ def test_replace_object(agent):

assert len(stub_data_svc.ram['tests']) == 1
assert not stub_data_svc.ram['tests'][0].value


def test_sanitize_id():
valid = '766be199-7316-4b26-b3db-e272aaf7e0d4'
assert valid == BaseApiManager._sanitize_id(valid)
assert valid.upper() == BaseApiManager._sanitize_id(valid.upper())
assert valid == BaseApiManager._sanitize_id('../.././&%$!"#766be19:9-73[16-]4b}26-b{3d!b-e272..\\//aaf/*7e0d4')
assert 'testid123TEST' == BaseApiManager._sanitize_id('testid123TEST')
with pytest.raises(DataValidationError):
BaseApiManager._sanitize_id('../../.')
with pytest.raises(DataValidationError):
BaseApiManager._sanitize_id('')
# Non-string IDs should raise a DataValidationError
with pytest.raises(DataValidationError):
BaseApiManager._sanitize_id(12345)


def test_sanitize_id_logs_warning_when_changed(caplog):
# Capture warnings when an ID is mutated by sanitization
caplog.set_level('WARNING')
original = 'abc/def?ghi'
sanitized = BaseApiManager._sanitize_id(original)
assert sanitized == 'abcdefghi'
# Ensure a warning was emitted that includes the sanitized ID
assert any('Sanitized ID' in rec.getMessage() and sanitized in rec.getMessage() for rec in caplog.records)
Loading