Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 9 additions & 25 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::path::Path;

/// Core structure of the config
#[pyclass]
#[derive(Clone)]
pub struct Config {
/// The size of the memory (history).
/// The number of elements that should be tracked.
Expand All @@ -17,7 +18,7 @@ pub struct Config {
pub page_size: Option<usize>,
/// Whether the predicate should be automatically committed.
pub auto_commit: Option<bool>,
config: NativeConfig,
pub(crate) engine: NativeConfig,
}

#[pymethods]
Expand All @@ -26,46 +27,29 @@ impl Config {
#[new]
pub fn new(config_file: String) -> PyResult<Self> {
let config_file = Path::new(&config_file);
let config = NativeConfig::from_file(config_file).map_err(|err| {
let engine = NativeConfig::from_file(config_file).map_err(|err| {
pyo3::exceptions::PyValueError::new_err(format!(
"Failed to load config file `{config_file:?}`.\nCaused by:\n\t{err}"
))
})?;

Ok(Self {
buffer_size: config.core.as_ref().and_then(|c| c.buffer_size),
page_size: config.core.as_ref().and_then(|c| c.page_size),
auto_commit: config.core.as_ref().and_then(|c| c.auto_commit),
config,
buffer_size: engine.core.as_ref().and_then(|c| c.buffer_size),
page_size: engine.core.as_ref().and_then(|c| c.page_size),
auto_commit: engine.core.as_ref().and_then(|c| c.auto_commit),
engine,
})
}

/// Extracts the data from the configuration.
pub fn extract_data(&self, py: Python) -> PyResult<PyObject> {
let value = serde_json::to_value(self.config.extract_data()).unwrap();
let value = serde_json::to_value(self.engine.extract_data()).unwrap();
pythonize(py, &value).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
}

/// Extracts the translation from the configuration.
pub fn extract_translation(&self, py: Python) -> PyResult<PyObject> {
let value = serde_json::to_value(self.config.extract_translation()).unwrap();

pythonize(py, &value).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
}

/// Extracts the translators from the configuration.
#[cfg(feature = "rhai")]
pub fn extract_translators(&self, py: Python) -> PyResult<PyObject> {
let translators = self.config.extract_translators().map_err(|err| {
pyo3::exceptions::PyValueError::new_err(format!(
"Failed to load the translators`.\nCaused by:\n\t{err}"
))
})?;
let translators = translators
.into_iter()
.map(|(k, v)| (k, format!("{v:?}")))
.collect::<HashMap<String, String>>();
let value = serde_json::to_value(&translators).unwrap();
let value = serde_json::to_value(self.engine.extract_translation()).unwrap();

pythonize(py, &value).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
}
Expand Down
17 changes: 17 additions & 0 deletions src/translator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![deny(missing_docs)]
//! Python binding of the afrim translator.

use crate::config::Config;
#[cfg(feature = "rhai")]
use afrim_translator::Engine;
use afrim_translator::Translator as NativeTranslator;
Expand Down Expand Up @@ -46,6 +47,22 @@ impl Translator {
Ok(())
}

/// Register Rhai translators from a config object (requires `rhai` feature)
#[cfg(feature = "rhai")]
fn register_from_config(&mut self, config: Config) -> PyResult<()> {
// Extracts the translators from the configuration.
let translators = config.engine.extract_translators().map_err(|err| {
pyo3::exceptions::PyValueError::new_err(format!(
"Failed to load the translators`.\nCaused by:\n\t{err}"
))
})?;

translators
.into_iter()
.for_each(|(name, ast)| self.engine.register(name, ast));
Ok(())
}

/// Unregister (requires `rhai` feature)
#[cfg(feature = "rhai")]
fn unregister(&mut self, name: &str) {
Expand Down
18 changes: 0 additions & 18 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,6 @@ def test_config_file_with_translators(self):
with pytest.raises(ValueError, match="Failed to load config file"):
Config(DATA_DIR + "/invalid_translator.toml")

#
config = Config(DATA_DIR + "/config_sample.toml")
assert isinstance(config.extract_translators(), dict)

# no translators
config = Config(DATA_DIR + "/blank_sample.toml")
assert config.extract_translators() == {}

# invalid script
config = Config(DATA_DIR + "/bad_script2.toml")
with pytest.raises(ValueError, match="Failed to load the translators"):
config.extract_translators()

# script not found
config = Config(DATA_DIR + "/bad_script.toml")
with pytest.raises(ValueError, match="Failed to load the translators"):
config.extract_translators()

def test_config_file_with_translation(self):
config = Config(DATA_DIR + "/config_sample.toml")
assert isinstance(config.extract_translation(), dict)
Expand Down
27 changes: 26 additions & 1 deletion tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""Tests for Translator functionality."""

import pytest
from afrim_py import Translator, is_rhai_feature_enabled
from afrim_py import Translator, Config, is_rhai_feature_enabled
from pathlib import Path


BASE_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = (BASE_DIR / "data").as_posix()


class TestTranslator:
Expand Down Expand Up @@ -249,6 +254,26 @@ def test_register_unregister_script(self):
translator.register("test_script", "fn main(input) { [input] }")
translator.unregister("test_script")

def test_register_from_config(self):
translator = Translator({}, True)

config = Config(DATA_DIR + "/config_sample.toml")
translator.register_from_config(config)

# no translators
config = Config(DATA_DIR + "/blank_sample.toml")
translator.register_from_config(config)

# invalid script
config = Config(DATA_DIR + "/bad_script2.toml")
with pytest.raises(ValueError, match="Failed to load the translators"):
translator.register_from_config(config)

# script not found
config = Config(DATA_DIR + "/bad_script.toml")
with pytest.raises(ValueError, match="Failed to load the translators"):
translator.register_from_config(config)

def test_whitespace_handling(self):
"""Test translation with whitespace in keys and values."""
dictionary = {
Expand Down
Loading