diff --git a/src/config.rs b/src/config.rs index a9d6b9c..d8a4ee6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,6 +9,7 @@ use std::path::Path; /// Core structure of the config #[pyclass] +#[derive(Clone)] pub struct Config { /// The size of the memory (history). /// The number of elements that should be tracked. @@ -17,7 +18,7 @@ pub struct Config { pub page_size: Option, /// Whether the predicate should be automatically committed. pub auto_commit: Option, - config: NativeConfig, + pub(crate) engine: NativeConfig, } #[pymethods] @@ -26,46 +27,29 @@ impl Config { #[new] pub fn new(config_file: String) -> PyResult { let config_file = Path::new(&config_file); - let config = NativeConfig::from_file(config_file).map_err(|err| { + let engine = NativeConfig::from_file(config_file).map_err(|err| { pyo3::exceptions::PyValueError::new_err(format!( "Failed to load config file `{config_file:?}`.\nCaused by:\n\t{err}" )) })?; Ok(Self { - buffer_size: config.core.as_ref().and_then(|c| c.buffer_size), - page_size: config.core.as_ref().and_then(|c| c.page_size), - auto_commit: config.core.as_ref().and_then(|c| c.auto_commit), - config, + buffer_size: engine.core.as_ref().and_then(|c| c.buffer_size), + page_size: engine.core.as_ref().and_then(|c| c.page_size), + auto_commit: engine.core.as_ref().and_then(|c| c.auto_commit), + engine, }) } /// Extracts the data from the configuration. pub fn extract_data(&self, py: Python) -> PyResult { - let value = serde_json::to_value(self.config.extract_data()).unwrap(); + let value = serde_json::to_value(self.engine.extract_data()).unwrap(); pythonize(py, &value).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } /// Extracts the translation from the configuration. pub fn extract_translation(&self, py: Python) -> PyResult { - let value = serde_json::to_value(self.config.extract_translation()).unwrap(); - - pythonize(py, &value).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) - } - - /// Extracts the translators from the configuration. - #[cfg(feature = "rhai")] - pub fn extract_translators(&self, py: Python) -> PyResult { - let translators = self.config.extract_translators().map_err(|err| { - pyo3::exceptions::PyValueError::new_err(format!( - "Failed to load the translators`.\nCaused by:\n\t{err}" - )) - })?; - let translators = translators - .into_iter() - .map(|(k, v)| (k, format!("{v:?}"))) - .collect::>(); - let value = serde_json::to_value(&translators).unwrap(); + let value = serde_json::to_value(self.engine.extract_translation()).unwrap(); pythonize(py, &value).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } diff --git a/src/translator.rs b/src/translator.rs index dfcbce3..192d092 100644 --- a/src/translator.rs +++ b/src/translator.rs @@ -1,6 +1,7 @@ #![deny(missing_docs)] //! Python binding of the afrim translator. +use crate::config::Config; #[cfg(feature = "rhai")] use afrim_translator::Engine; use afrim_translator::Translator as NativeTranslator; @@ -46,6 +47,22 @@ impl Translator { Ok(()) } + /// Register Rhai translators from a config object (requires `rhai` feature) + #[cfg(feature = "rhai")] + fn register_from_config(&mut self, config: Config) -> PyResult<()> { + // Extracts the translators from the configuration. + let translators = config.engine.extract_translators().map_err(|err| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to load the translators`.\nCaused by:\n\t{err}" + )) + })?; + + translators + .into_iter() + .for_each(|(name, ast)| self.engine.register(name, ast)); + Ok(()) + } + /// Unregister (requires `rhai` feature) #[cfg(feature = "rhai")] fn unregister(&mut self, name: &str) { diff --git a/tests/test_config.py b/tests/test_config.py index 56a33c1..103b008 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -40,24 +40,6 @@ def test_config_file_with_translators(self): with pytest.raises(ValueError, match="Failed to load config file"): Config(DATA_DIR + "/invalid_translator.toml") - # - config = Config(DATA_DIR + "/config_sample.toml") - assert isinstance(config.extract_translators(), dict) - - # no translators - config = Config(DATA_DIR + "/blank_sample.toml") - assert config.extract_translators() == {} - - # invalid script - config = Config(DATA_DIR + "/bad_script2.toml") - with pytest.raises(ValueError, match="Failed to load the translators"): - config.extract_translators() - - # script not found - config = Config(DATA_DIR + "/bad_script.toml") - with pytest.raises(ValueError, match="Failed to load the translators"): - config.extract_translators() - def test_config_file_with_translation(self): config = Config(DATA_DIR + "/config_sample.toml") assert isinstance(config.extract_translation(), dict) diff --git a/tests/test_translator.py b/tests/test_translator.py index 3802889..7a232e9 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -1,7 +1,12 @@ """Tests for Translator functionality.""" import pytest -from afrim_py import Translator, is_rhai_feature_enabled +from afrim_py import Translator, Config, is_rhai_feature_enabled +from pathlib import Path + + +BASE_DIR = Path(__file__).resolve().parent.parent +DATA_DIR = (BASE_DIR / "data").as_posix() class TestTranslator: @@ -249,6 +254,26 @@ def test_register_unregister_script(self): translator.register("test_script", "fn main(input) { [input] }") translator.unregister("test_script") + def test_register_from_config(self): + translator = Translator({}, True) + + config = Config(DATA_DIR + "/config_sample.toml") + translator.register_from_config(config) + + # no translators + config = Config(DATA_DIR + "/blank_sample.toml") + translator.register_from_config(config) + + # invalid script + config = Config(DATA_DIR + "/bad_script2.toml") + with pytest.raises(ValueError, match="Failed to load the translators"): + translator.register_from_config(config) + + # script not found + config = Config(DATA_DIR + "/bad_script.toml") + with pytest.raises(ValueError, match="Failed to load the translators"): + translator.register_from_config(config) + def test_whitespace_handling(self): """Test translation with whitespace in keys and values.""" dictionary = {