diff --git a/tests/api/v2/handlers/test_abilities_api.py b/tests/api/v2/handlers/test_abilities_api.py index ee25f040c..2d223fc29 100644 --- a/tests/api/v2/handlers/test_abilities_api.py +++ b/tests/api/v2/handlers/test_abilities_api.py @@ -1,3 +1,4 @@ +import os import pytest from http import HTTPStatus @@ -11,23 +12,31 @@ @pytest.fixture def new_ability_payload(): test_executor_linux = Executor(name='sh', platform='linux', command='whoami') - return {'name': 'new test ability', - 'ability_id': '456', - 'tactic': 'collection', - 'technique_name': 'collection', - 'technique_id': '1', - 'executors': [ExecutorSchema().dump(test_executor_linux)], - 'access': {}, - 'additional_info': {}, - 'buckets': ['collection'], - 'description': '', - 'privilege': '', - 'repeatable': False, - 'requirements': [], - 'singleton': False, - 'plugin': '', - 'delete_payload': True, - } + yield { + 'name': 'new test ability', + 'ability_id': '456', + 'tactic': 'collection', + 'technique_name': 'collection', + 'technique_id': '1', + 'executors': [ExecutorSchema().dump(test_executor_linux)], + 'access': {}, + 'additional_info': {}, + 'buckets': ['collection'], + 'description': '', + 'privilege': '', + 'repeatable': False, + 'requirements': [], + 'singleton': False, + 'plugin': '', + 'delete_payload': True, + } + + # Ability cleanup + if os.path.exists('data/abilities/collection/456.yml'): + try: + os.remove('data/abilities/collection/456.yml') + except OSError: + pass @pytest.fixture @@ -56,7 +65,14 @@ def test_ability(event_loop, api_v2_client, executor): technique_name='collection', technique_id='1', description='', privilege='', tactic='discovery', plugin='testplugin') event_loop.run_until_complete(BaseService.get_service('data_svc').store(ability)) - return ability + yield ability + + # cleanup + if os.path.exists('data/abilities/collection/123.yml'): + try: + os.remove('data/abilities/collection/123.yml') + except OSError: + pass class TestAbilitiesApi: diff --git a/tests/api/v2/handlers/test_adversaries_api.py b/tests/api/v2/handlers/test_adversaries_api.py index 18ccf481c..f3af218f5 100644 --- a/tests/api/v2/handlers/test_adversaries_api.py +++ b/tests/api/v2/handlers/test_adversaries_api.py @@ -1,3 +1,4 @@ +import os import pytest from http import HTTPStatus @@ -6,6 +7,15 @@ from app.utility.base_service import BaseService +def adversary_file_cleanup(adversary_id): + file_path = f'data/adversaries/{adversary_id}.yml' + if os.path.exists(file_path): + try: + os.remove(file_path) + except OSError: + pass + + @pytest.fixture def updated_adversary_payload(): return { @@ -33,7 +43,7 @@ def expected_updated_adversary_dump(test_adversary, updated_adversary_payload): @pytest.fixture def new_adversary_payload(): - return { + yield { 'name': 'test new adversary', 'description': 'a new adversary', 'adversary_id': '456', @@ -43,6 +53,8 @@ def new_adversary_payload(): 'plugin': '' } + adversary_file_cleanup('456') + @pytest.fixture def expected_new_adversary_dump(new_adversary_payload): @@ -52,7 +64,7 @@ def expected_new_adversary_dump(new_adversary_payload): @pytest.fixture def new_adversary_duplicate_id_payload(): - return { + yield { 'name': 'test new adversary', 'description': 'a new adversary with an invalid payload', 'adversary_id': '456', @@ -63,6 +75,8 @@ def new_adversary_duplicate_id_payload(): 'plugin': '' } + adversary_file_cleanup('456') + @pytest.fixture def expected_new_duplicate_id_adversary_dump(new_adversary_duplicate_id_payload): @@ -83,7 +97,9 @@ def test_adversary(event_loop): 'plugin': ''} test_adversary = AdversarySchema().load(expected_adversary) event_loop.run_until_complete(BaseService.get_service('data_svc').store(test_adversary)) - return test_adversary + yield test_adversary + + adversary_file_cleanup(test_adversary.adversary_id) class TestAdversariesApi: diff --git a/tests/api/v2/handlers/test_planners_api.py b/tests/api/v2/handlers/test_planners_api.py index 09b371115..991f5696c 100644 --- a/tests/api/v2/handlers/test_planners_api.py +++ b/tests/api/v2/handlers/test_planners_api.py @@ -1,3 +1,4 @@ +import os import pytest from http import HTTPStatus @@ -10,7 +11,14 @@ def test_planner(event_loop, api_v2_client): planner = Planner(name="123test planner", planner_id="123", description="a test planner", plugin="planner") event_loop.run_until_complete(BaseService.get_service('data_svc').store(planner)) - return planner + yield planner + + # Planner cleanup + if os.path.exists('data/planners/123.yml'): + try: + os.remove('data/planners/123.yml') + except OSError: + pass @pytest.fixture