diff --git a/src/rejax/algos/algorithm.py b/src/rejax/algos/algorithm.py index 42cfa33..63d105a 100644 --- a/src/rejax/algos/algorithm.py +++ b/src/rejax/algos/algorithm.py @@ -35,6 +35,7 @@ class Algorithm(struct.PyTreeNode): @classmethod def create(cls, **config): + config = deepcopy(config) env, env_params = cls.create_env(config) agent = cls.create_agent(config, env, env_params) diff --git a/tests/test_configs.py b/tests/test_configs.py index 49248a7..d18d04e 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -1,6 +1,6 @@ import os import unittest - +from copy import deepcopy from yaml import safe_load from rejax import get_algo @@ -31,3 +31,20 @@ def test_configs(self) -> None: f"Failed to create {algo} with config '{config_path}': " f"{type(e).__name__}: {str(e)}" ) + + def test_create_does_not_modify_config(self) -> None: + for config_path, configs_env in self.configs.items(): + for algo, config in configs_env.items(): + if config.get("env", "").startswith("navix"): + continue + with self.subTest(config_opath=config_path, algo=algo): + try: + original_config = deepcopy(config) + algo_cls = get_algo(algo) + algo_cls.create(**config) + self.assertEqual(config, original_config) + except Exception as e: + self.fail( + f"Config '{config_path}' for {algo} has been modified: " + f"{type(e).__name__}: {str(e)}" + )