Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions agrifoodpy/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
This module provides methods to build a pipeline for the AgriFoodPy package.
"""

from .pipeline import *
from ..utils.dict_utils import *
from .pipeline import *
Comment thread
jucordero marked this conversation as resolved.
84 changes: 81 additions & 3 deletions agrifoodpy/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import yaml
import importlib

from ..utils.dict_utils import get_dict, set_dict

class Pipeline():
'''Class for constructing and running pipelines of functions with
Expand Down Expand Up @@ -80,7 +80,7 @@ def datablock_write(self, path, value):
current = current.setdefault(key, {})
current[path[-1]] = value

def add_node(self, node, params={}, name=None, index=None):
def add_node(self, node, params=None, name=None, index=None):
"""Adds a node to the pipeline, including its function and execution
parameters.

Expand All @@ -99,7 +99,7 @@ def add_node(self, node, params={}, name=None, index=None):
"""

# Copy the parameters to avoid modifying the original dictionaries
params = copy.deepcopy(params)
params = copy.deepcopy(params) if params is not None else {}

if name is None:
name = "Node {}".format(len(self.nodes) + 1)
Expand Down Expand Up @@ -293,3 +293,81 @@ def wrapper(*args, **kwargs):
return result
return wrapper
return pipeline_decorator


def pipeline_node(input_keys=None):
    """Decorator to make a function compatible with pipeline execution.

    The decorated function can be called in two modes:

    * Standalone mode (no ``datablock`` kwarg, or ``datablock=None``): the
      arguments are bound to the function signature, defaults are applied
      and the function result is returned unchanged.
    * Pipeline mode (a ``datablock`` kwarg is passed): the values of the
      parameters named in ``input_keys`` are interpreted as datablock lookup
      keys. The corresponding objects are extracted from the datablock with
      ``get_dict`` and passed to the function. Unregistered keyword
      arguments are passed directly to the function. The result is stored
      in the datablock with ``set_dict`` and the datablock is returned.

    The decorated function takes a ``return_key`` kwarg to specify the key
    under which the function output will be stored in the datablock. If not
    provided, the function name will be used as the return key.

    Parameters
    ----------
    input_keys: string or iterable of strings, optional
        Names of the decorated function parameters whose values will be used
        as datablock lookup keys in pipeline mode. A single string is
        treated as a one-element list. Defaults to no input keys.

    Returns
    -------
    pipeline_decorator: function
        Decorator which, applied to a function, returns the wrapped function.

    Raises
    ------
    ValueError
        At decoration time, if the decorated function declares a parameter
        named ``datablock`` or ``return_key`` (reserved by the wrapper), or
        if ``input_keys`` names a parameter the function does not have.
    TypeError
        At call time, if the supplied arguments cannot be bound to the
        signature of the decorated function.
    """

    if input_keys is None:
        input_keys = []
    elif isinstance(input_keys, str):
        input_keys = [input_keys]
    else:
        # The keys are read during validation and again on every call, so a
        # one-shot iterable (e.g. a generator) must be materialised once.
        input_keys = list(input_keys)

    def pipeline_decorator(func):
        # The signature is fixed at decoration time: compute it once and
        # reuse it for validation and for every call of the wrapper.
        func_sig = signature(func)
        func_params = func_sig.parameters

        reserved = {"datablock", "return_key"}
        clashes = reserved & set(func_params)
        if clashes:
            raise ValueError(f"Function {func.__name__} has reserved parameter"
                             f" names {clashes}. "
                             "Please rename these parameters to use the "
                             "pipeline_node decorator.")

        unknown = set(input_keys) - set(func_params.keys())
        if unknown:
            raise ValueError(f"input_keys {unknown} not found in parameters "
                             f"of '{func.__name__}'")

        @wraps(func)
        def wrapper(*args, **kwargs):

            # Pop wrapper-specific kwargs so they never reach the function
            datablock = kwargs.pop("datablock", None)
            return_key = kwargs.pop("return_key", func.__name__)

            # Bind positional and keyword args to their parameter names
            try:
                bound = func_sig.bind(*args, **kwargs)
            except TypeError as e:
                raise TypeError(
                    f"Invalid arguments for function {func.__name__}."
                ) from e

            bound.apply_defaults()

            # Standalone mode: behave exactly like the undecorated function
            if datablock is None:
                return func(*bound.args, **bound.kwargs)

            # Pipeline mode: swap each registered lookup key for the object
            # it points to in the datablock before calling the function
            for key in input_keys:
                bound.arguments[key] = get_dict(datablock,
                                                bound.arguments[key])
            result = func(*bound.args, **bound.kwargs)

            set_dict(datablock, return_key, result)

            return datablock
        return wrapper
    return pipeline_decorator
129 changes: 127 additions & 2 deletions agrifoodpy/pipeline/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from agrifoodpy.pipeline.pipeline import Pipeline, standalone
from agrifoodpy.pipeline import Pipeline, standalone
import pytest

def test_init():
pipeline = Pipeline()
Expand Down Expand Up @@ -202,4 +203,128 @@ def pipeline_decorated(x, out_key, datablock=None):

pipeline.add_node(pipeline_decorated, params={'x': 'x', 'out_key': 'result'})
pipeline.run()
assert pipeline.datablock['result'] == 15
assert pipeline.datablock['result'] == 15

def test_pipeline_node_decorator():
    """Check ``pipeline_node`` both standalone and as a pipeline node.

    Each scenario builds a small datablock, registers a decorated function
    as a node, runs the pipeline, and then verifies the standalone result
    of the function as well as the value written to the datablock.
    """

    from agrifoodpy.pipeline.pipeline import Pipeline, pipeline_node

    # --- One registered input key, default return key -------------------
    single_block = {'value1': 5, 'value2': 10}
    single_pipe = Pipeline(single_block)

    @pipeline_node('x')
    def double_numbers(x):
        return x * 2

    single_pipe.add_node(double_numbers, params={'x': 'value1'})
    single_pipe.run()

    assert double_numbers(single_block['value1']) == 10
    assert double_numbers.__name__ in single_pipe.datablock
    assert single_pipe.datablock[double_numbers.__name__] == 10

    # --- One registered input key plus an unregistered kwarg ------------
    @pipeline_node('value')
    def scale_numbers(value, factor=3):
        return value * factor

    mixed_block = {'value': 5}
    mixed_pipe = Pipeline(mixed_block)
    mixed_pipe.add_node(scale_numbers, params={'value': 'value', 'factor': 4})
    mixed_pipe.run()

    assert scale_numbers(mixed_block['value'], factor=4) == 20
    assert scale_numbers.__name__ in mixed_pipe.datablock
    assert mixed_pipe.datablock[scale_numbers.__name__] == 20

    # --- Several registered input keys, default return key --------------
    multi_block = {'value1': 5, 'value2': 10}
    multi_pipe = Pipeline(multi_block)

    @pipeline_node(['x', 'y'])
    def sum_numbers(x, y):
        return x + y

    multi_pipe.add_node(sum_numbers, params={'x': 'value1', 'y': 'value2'})
    multi_pipe.run()

    assert sum_numbers(multi_block['value1'], multi_block['value2']) == 15
    assert sum_numbers.__name__ in multi_pipe.datablock
    assert multi_pipe.datablock[sum_numbers.__name__] == 15

    # --- Several registered input keys, explicit return key -------------
    keyed_block = {'value1': 5, 'value2': 10}
    keyed_pipe = Pipeline(keyed_block)
    out_key = "result"

    @pipeline_node(['x', 'y'])
    def subtract_numbers(x, y):
        return x - y

    keyed_pipe.add_node(
        subtract_numbers,
        params={'x': 'value1', 'y': 'value2', "return_key": out_key},
    )
    keyed_pipe.run()

    assert subtract_numbers(keyed_block['value1'], keyed_block['value2']) == -5
    assert out_key in keyed_pipe.datablock
    assert keyed_pipe.datablock[out_key] == -5

    # --- Wrapping a function defined outside this module ----------------
    external_block = {'value1': [1, 2, 3]}
    external_pipe = Pipeline(external_block)

    import numpy as np

    external_pipe.add_node(
        pipeline_node(input_keys="a")(np.mean),
        params={'a': 'value1', 'return_key': "mean_result"},
    )
    external_pipe.run()

    assert np.mean(external_block['value1']) == 2
    assert "mean_result" in external_pipe.datablock
    assert external_pipe.datablock["mean_result"] == 2

    # --- No input keys at all --------------------------------------------
    empty_pipe = Pipeline()

    @pipeline_node([])
    def return_constant():
        return 42

    empty_pipe.add_node(return_constant)
    empty_pipe.run()

    assert return_constant() == 42
    assert return_constant.__name__ in empty_pipe.datablock
    assert empty_pipe.datablock[return_constant.__name__] == 42

    # --- Reserved parameter names are rejected at decoration time -------
    with pytest.raises(ValueError, match="reserved parameter names.*datablock"):
        @pipeline_node(['x'])
        def reserved_param_node(x, datablock=None):
            pass

    # --- Unknown input keys are rejected at decoration time -------------
    with pytest.raises(ValueError, match="input_keys.*not found in parameters"):
        @pipeline_node(['wrong_key'])
        def unknown_input_node(right_key):
            pass
Loading