diff --git a/agrifoodpy/pipeline/__init__.py b/agrifoodpy/pipeline/__init__.py index 26952e7..8982719 100644 --- a/agrifoodpy/pipeline/__init__.py +++ b/agrifoodpy/pipeline/__init__.py @@ -2,5 +2,4 @@ This module provides methods to build a pipeline for the AgriFoodPy package. """ -from .pipeline import * -from ..utils.dict_utils import * \ No newline at end of file +from .pipeline import * \ No newline at end of file diff --git a/agrifoodpy/pipeline/pipeline.py b/agrifoodpy/pipeline/pipeline.py index 9104a61..08fda2f 100644 --- a/agrifoodpy/pipeline/pipeline.py +++ b/agrifoodpy/pipeline/pipeline.py @@ -10,7 +10,7 @@ import time import yaml import importlib - +from ..utils.dict_utils import get_dict, set_dict class Pipeline(): '''Class for constructing and running pipelines of functions with @@ -80,7 +80,7 @@ def datablock_write(self, path, value): current = current.setdefault(key, {}) current[path[-1]] = value - def add_node(self, node, params={}, name=None, index=None): + def add_node(self, node, params=None, name=None, index=None): """Adds a node to the pipeline, including its function and execution parameters. @@ -99,7 +99,7 @@ def add_node(self, node, params={}, name=None, index=None): """ # Copy the parameters to avoid modifying the original dictionaries - params = copy.deepcopy(params) + params = copy.deepcopy(params) if params is not None else {} if name is None: name = "Node {}".format(len(self.nodes) + 1) @@ -293,3 +293,81 @@ def wrapper(*args, **kwargs): return result return wrapper return pipeline_decorator + + +def pipeline_node(input_keys=None): + """ Decorator to make a function compatible with pipeline execution + + If a datablock is passed as a kwarg, the function will be executed in + pipeline mode, and the values of the parameters named in input_keys will + be interpreted as datablock lookup keys. The corresponding objects will be + extracted from the datablock and passed to the function. Unregistered + keyword arguments will be passed directly to the function. The decorated + function takes a "return_key" kwarg to specify the key under which the + function output will be stored in the datablock. If not provided, the + function name will be used as the return key. + + Parameters + ---------- + input_keys: string or list of strings, optional + List of decorated function parameter names whose values will be used as + datablock lookup keys in pipeline mode. + + Returns + ------- + wrapper: function + The decorated function + """ + + if input_keys is not None: + if isinstance(input_keys, str): + input_keys = [input_keys] + else: + input_keys = [] + + def pipeline_decorator(func): + reserved = {"datablock", "return_key"} + if reserved & set(signature(func).parameters): + raise ValueError(f"Function {func.__name__} has reserved parameter" + f" names {reserved & set(signature(func).parameters)}." + "Please rename these parameters to use the" + "pipeline_node decorator.") + + func_params = signature(func).parameters + unknown = set(input_keys) - set(func_params.keys()) + if unknown: + raise ValueError(f"input_keys {unknown} not found in parameters " + f"of '{func.__name__}'") + + @wraps(func) + def wrapper(*args, **kwargs): + + # Pop wrapper-specific kwargs + datablock = kwargs.pop("datablock", None) + return_key = kwargs.pop("return_key", func.__name__) + + # Bind positional and keyword args to their parameter names + func_sig = signature(func) + try: + bound = func_sig.bind(*args, **kwargs) + except TypeError as e: + raise TypeError( + f"Invalid arguments for function {func.__name__}." + ) from e + + bound.apply_defaults() + + if datablock is None: + return func(*bound.args, **bound.kwargs) + + else: + for key in input_keys: + bound.arguments[key] = get_dict(datablock, + bound.arguments[key]) + result = func(*bound.args, **bound.kwargs) + + set_dict(datablock, return_key, result) + + return datablock + return wrapper + return pipeline_decorator \ No newline at end of file diff --git a/agrifoodpy/pipeline/tests/test_pipeline.py b/agrifoodpy/pipeline/tests/test_pipeline.py index 4e05d49..8170a81 100644 --- a/agrifoodpy/pipeline/tests/test_pipeline.py +++ b/agrifoodpy/pipeline/tests/test_pipeline.py @@ -1,4 +1,5 @@ -from agrifoodpy.pipeline.pipeline import Pipeline, standalone +from agrifoodpy.pipeline import Pipeline, standalone +import pytest def test_init(): pipeline = Pipeline() @@ -202,4 +203,128 @@ def pipeline_decorated(x, out_key, datablock=None): pipeline.add_node(pipeline_decorated, params={'x': 'x', 'out_key': 'result'}) pipeline.run() - assert pipeline.datablock['result'] == 15 \ No newline at end of file + assert pipeline.datablock['result'] == 15 + +def test_pipeline_node_decorator(): + + from agrifoodpy.pipeline.pipeline import Pipeline, pipeline_node + + test_datablock_single = {'value1': 5, 'value2': 10} + test_pipeline_single = Pipeline(test_datablock_single) + + # Test decorated function with single input key and no return key + @pipeline_node('x') + def double_numbers(x): + return x * 2 + + test_pipeline_single.add_node( + double_numbers, + params={'x': 'value1'} + ) + + test_pipeline_single.run() + assert double_numbers(test_datablock_single['value1']) == 10 + assert double_numbers.__name__ in test_pipeline_single.datablock + assert test_pipeline_single.datablock[double_numbers.__name__] == 10 + + # Test decorated function with single input key and unregistered key + @pipeline_node('value') + def scale_numbers(value, factor=3): + return value * factor + + test_datablock_mixed = {'value': 5} + + test_pipeline_mixed = Pipeline(test_datablock_mixed) + test_pipeline_mixed.add_node( + scale_numbers, + params={'value': 'value', 'factor': 4} + ) + test_pipeline_mixed.run() + assert scale_numbers(test_datablock_mixed['value'], factor=4) == 20 + assert scale_numbers.__name__ in test_pipeline_mixed.datablock + assert test_pipeline_mixed.datablock[scale_numbers.__name__] == 20 + + # Test decorated function with multiple input keys and no return key + test_datablock_multiple = {'value1': 5, 'value2': 10} + test_pipeline_multiple = Pipeline(test_datablock_multiple) + + @pipeline_node(['x', 'y']) + def sum_numbers(x, y): + return x + y + + test_pipeline_multiple.add_node( + sum_numbers, + params={'x': 'value1', 'y': 'value2'} + ) + + test_pipeline_multiple.run() + assert sum_numbers( + test_datablock_multiple['value1'], + test_datablock_multiple['value2']) == 15 + assert sum_numbers.__name__ in test_pipeline_multiple.datablock + assert test_pipeline_multiple.datablock[sum_numbers.__name__] == 15 + + # Test decorated function with multiple input keys and return key + test_datablock_with_return = {'value1': 5, 'value2': 10} + test_pipeline_with_return = Pipeline(test_datablock_with_return) + return_key = "result" + + @pipeline_node(['x', 'y']) + def subtract_numbers(x, y): + return x - y + + test_pipeline_with_return.add_node( + subtract_numbers, + params={'x': 'value1', 'y': 'value2', "return_key": return_key} + ) + + test_pipeline_with_return.run() + assert subtract_numbers( + test_datablock_with_return['value1'], + test_datablock_with_return['value2']) == -5 + assert return_key in test_pipeline_with_return.datablock + assert test_pipeline_with_return.datablock[return_key] == -5 + + #test decorated function with external function + test_datablock_external = {'value1': [1, 2, 3]} + test_pipeline_external = Pipeline(test_datablock_external) + + import numpy as np + + test_pipeline_external.add_node( + pipeline_node(input_keys="a")(np.mean), + params={'a': 'value1', 'return_key': "mean_result"} + ) + + test_pipeline_external.run() + assert np.mean(test_datablock_external['value1']) == 2 + assert "mean_result" in test_pipeline_external.datablock + assert test_pipeline_external.datablock["mean_result"] == 2 + + # Test decorated function with no input keys + test_pipeline_no_input = Pipeline() + + @pipeline_node([]) + def return_constant(): + return 42 + + test_pipeline_no_input.add_node( + return_constant + ) + + test_pipeline_no_input.run() + assert return_constant() == 42 + assert return_constant.__name__ in test_pipeline_no_input.datablock + assert test_pipeline_no_input.datablock[return_constant.__name__] == 42 + + # Test decorated function with reserved parameter names + with pytest.raises(ValueError, match="reserved parameter names.*datablock"): + @pipeline_node(['x']) + def reserved_param_node(x, datablock=None): + pass + + # Test decorated function with unknown input keys + with pytest.raises(ValueError, match="input_keys.*not found in parameters"): + @pipeline_node(['wrong_key']) + def unknown_input_node(right_key): + pass